{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Kuka\n",
    "\n",
    "---\n",
    "\n",
    "You are welcome to use this coding environment to train your agent for the project.  Follow the instructions below to get started!\n",
    "\n",
    "### Start the Environment\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make sure that you're in the right virtual environment and the right python version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Python 3.6.9 :: Anaconda, Inc.\n",
      "Requirement already satisfied: pybullet in /home/monster/anaconda3/envs/python36/lib/python3.6/site-packages (2.6.1)\n",
      "Requirement already satisfied: tensorboardX in /home/monster/anaconda3/envs/python36/lib/python3.6/site-packages (2.0)\n",
      "Requirement already satisfied: six in /home/monster/anaconda3/envs/python36/lib/python3.6/site-packages (from tensorboardX) (1.13.0)\n",
      "Requirement already satisfied: protobuf>=3.8.0 in /home/monster/anaconda3/envs/python36/lib/python3.6/site-packages (from tensorboardX) (3.11.2)\n",
      "Requirement already satisfied: numpy in /home/monster/anaconda3/envs/python36/lib/python3.6/site-packages (from tensorboardX) (1.17.4)\n",
      "Requirement already satisfied: setuptools in /home/monster/anaconda3/envs/python36/lib/python3.6/site-packages (from protobuf>=3.8.0->tensorboardX) (42.0.2.post20191203)\n"
     ]
    }
   ],
   "source": [
    "!python --version\n",
    "!pip install pybullet\n",
    "!pip install tensorboardX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "current_dir=/home/monster/anaconda3/envs/python36/lib/python3.6/site-packages/pybullet_envs/bullet\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import sys\n",
    "from collections import deque\n",
    "import timeit\n",
    "from datetime import timedelta\n",
    "from copy import deepcopy\n",
    "import numpy as np\n",
    "import random\n",
    "from PIL import Image\n",
    "from tensorboardX import SummaryWriter\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as T\n",
    "import torch.optim as optim\n",
    "\n",
    "from pybullet_envs.bullet.kuka_diverse_object_gym_env import KukaDiverseObjectEnv\n",
    "from gym import spaces\n",
    "import pybullet as p\n",
    "\n",
    "env = KukaDiverseObjectEnv(renders=False, isDiscrete=False, removeHeightHack=False, maxSteps=20)\n",
    "env.cid = p.connect(p.DIRECT)\n",
    "action_space = spaces.Box(low=-1, high=1, shape=(5,1))\n",
    "\n",
    "# if gpu is to be used\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Actor-Critic implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def build_hidden_layer(input_dim, hidden_layers):\n",
    "    \"\"\"Build hidden layer.\n",
    "    Params\n",
    "    ======\n",
    "        input_dim (int): Dimension of hidden layer input\n",
    "        hidden_layers (list(int)): Dimension of hidden layers\n",
    "    \"\"\"\n",
    "    hidden = nn.ModuleList([nn.Linear(input_dim, hidden_layers[0])])\n",
    "    if len(hidden_layers)>1:\n",
    "        layer_sizes = zip(hidden_layers[:-1], hidden_layers[1:])\n",
    "        hidden.extend([nn.Linear(h1, h2) for h1, h2 in layer_sizes])\n",
    "    return hidden\n",
    "\n",
    "class ActorCritic(nn.Module):\n",
    "    def __init__(self,state_size,action_size,shared_layers,\n",
    "                 critic_hidden_layers=[],actor_hidden_layers=[],\n",
    "                 seed=0, init_type=None):\n",
    "        \"\"\"Initialize parameters and build policy.\n",
    "        Params\n",
    "        ======\n",
    "            state_size (int,int,int): Dimension of each state\n",
    "            action_size (int): Dimension of each action\n",
    "            shared_layers (list(int)): Dimension of the shared hidden layers\n",
    "            critic_hidden_layers (list(int)): Dimension of the critic's hidden layers\n",
    "            actor_hidden_layers (list(int)): Dimension of the actor's hidden layers\n",
    "            seed (int): Random seed\n",
    "            init_type (str): Initialization type\n",
    "        \"\"\"\n",
    "        super(ActorCritic, self).__init__()\n",
    "        self.init_type = init_type\n",
    "        self.seed = torch.manual_seed(seed)\n",
    "        self.sigma = nn.Parameter(torch.zeros(action_size))\n",
    "\n",
    "        # Add shared hidden layer\n",
    "        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)\n",
    "        self.bn1 = nn.BatchNorm2d(16)\n",
    "        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)\n",
    "        self.bn2 = nn.BatchNorm2d(32)\n",
    "        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)\n",
    "        self.bn3 = nn.BatchNorm2d(32)\n",
    "\n",
    "        # Number of Linear input connections depends on output of conv2d layers\n",
    "        # and therefore the input image size, so compute it.\n",
    "        def conv2d_size_out(size, kernel_size = 5, stride = 2):\n",
    "            return (size - (kernel_size - 1) - 1) // stride  + 1\n",
    "        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(state_size[0])))\n",
    "        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(state_size[1])))\n",
    "        linear_input_size = convh * convw * 32\n",
    "        self.shared_layers = build_hidden_layer(input_dim=linear_input_size,\n",
    "                                                hidden_layers=shared_layers)\n",
    "\n",
    "        # Add critic layers\n",
    "        if critic_hidden_layers:\n",
    "            # Add hidden layers for critic net if critic_hidden_layers is not empty\n",
    "            self.critic_hidden = build_hidden_layer(input_dim=shared_layers[-1],\n",
    "                                                    hidden_layers=critic_hidden_layers)\n",
    "            self.critic = nn.Linear(critic_hidden_layers[-1], 1)\n",
    "        else:\n",
    "            self.critic_hidden = None\n",
    "            self.critic = nn.Linear(shared_layers[-1], 1)\n",
    "\n",
    "        # Add actor layers\n",
    "        if actor_hidden_layers:\n",
    "            # Add hidden layers for actor net if actor_hidden_layers is not empty\n",
    "            self.actor_hidden = build_hidden_layer(input_dim=shared_layers[-1],\n",
    "                                                   hidden_layers=actor_hidden_layers)\n",
    "            self.actor = nn.Linear(actor_hidden_layers[-1], action_size)\n",
    "        else:\n",
    "            self.actor_hidden = None\n",
    "            self.actor = nn.Linear(shared_layers[-1], action_size)\n",
    "\n",
    "        # Apply Tanh() to bound the actions\n",
    "        self.tanh = nn.Tanh()\n",
    "\n",
    "        # Initialize hidden and actor-critic layers\n",
    "        if self.init_type is not None:\n",
    "            self.shared_layers.apply(self._initialize)\n",
    "            self.critic.apply(self._initialize)\n",
    "            self.actor.apply(self._initialize)\n",
    "            if self.critic_hidden is not None:\n",
    "                self.critic_hidden.apply(self._initialize)\n",
    "            if self.actor_hidden is not None:\n",
    "                self.actor_hidden.apply(self._initialize)\n",
    "\n",
    "    def _initialize(self, n):\n",
    "        \"\"\"Initialize network weights.\n",
    "        \"\"\"\n",
    "        if isinstance(n, nn.Linear):\n",
    "            if self.init_type=='xavier-uniform':\n",
    "                nn.init.xavier_uniform_(n.weight.data)\n",
    "            elif self.init_type=='xavier-normal':\n",
    "                nn.init.xavier_normal_(n.weight.data)\n",
    "            elif self.init_type=='kaiming-uniform':\n",
    "                nn.init.kaiming_uniform_(n.weight.data)\n",
    "            elif self.init_type=='kaiming-normal':\n",
    "                nn.init.kaiming_normal_(n.weight.data)\n",
    "            elif self.init_type=='orthogonal':\n",
    "                nn.init.orthogonal_(n.weight.data)\n",
    "            elif self.init_type=='uniform':\n",
    "                nn.init.uniform_(n.weight.data)\n",
    "            elif self.init_type=='normal':\n",
    "                nn.init.normal_(n.weight.data)\n",
    "            else:\n",
    "                raise KeyError('initialization type is not found in the set of existing types')\n",
    "\n",
    "    def forward(self, state):\n",
    "        \"\"\"Build a network that maps state -> (action, value).\"\"\"\n",
    "        def apply_multi_layer(layers,x,f=F.leaky_relu):\n",
    "            for layer in layers:\n",
    "                x = f(layer(x))\n",
    "            return x\n",
    "\n",
    "        state = F.relu(self.bn1(self.conv1(state)))\n",
    "        state = F.relu(self.bn2(self.conv2(state)))\n",
    "        state = F.relu(self.bn3(self.conv3(state)))\n",
    "        state = apply_multi_layer(self.shared_layers,state.view(state.size(0),-1))\n",
    "\n",
    "        v_hid = state\n",
    "        if self.critic_hidden is not None:\n",
    "            v_hid = apply_multi_layer(self.critic_hidden,v_hid)\n",
    "\n",
    "        a_hid = state\n",
    "        if self.actor_hidden is not None:\n",
    "            a_hid = apply_multi_layer(self.actor_hidden,a_hid)\n",
    "\n",
    "        a = self.tanh(self.actor(a_hid))\n",
    "        value = self.critic(v_hid).squeeze(-1)\n",
    "        return a, value"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Examine the state and action spaces."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of agents: 1\n",
      "Size of each action: 3\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO2deZRldXXvP/uONY89D9DQNCAoYABBo4ZgNK2JD0wMhrzwMMvExKcrccXnEPLWcnjJiq7nlPVMjBhUEo1DUASNEyAOKBEZmrGBbuiGHqqru7q65ukOv/fHPU1u1d6/6uqurqpuzv6sVavu3fd3zvmd3zn7nnO/Z//2lhACjuM8/8ksdQccx1kc3NkdJyW4sztOSnBnd5yU4M7uOCnBnd1xUoI7+wmEiLxZRO5a6n6cSIjIBhEJIpJb6r6c7KTG2UVkp4iMi8hI3d+nlrpfS42IXCYiuxdw/R8QkS8u1PqduZO2b8vXhxBuX+pOnGyISC6EUF7qfiwEz+d9m0lqruyzISKfFpGv173/iIjcITU6ReTbInJARA4lr9fVtf2RiPyNiPw8uVv4loh0i8iXRGRIRH4pIhvq2gcR+XMReVpE+kTk/4qIeRxE5GwRuU1E+kXkCRG5apZ9aBeRG0SkR0T2JH3KHmH/moHvAmvq7nbWJFfjm0TkiyIyBLxZRF4iIneLyECyjU+JSKFunefW9bVXRK4Tkc3AdcCbknU/OIe+ZkXko8nYPA381hGO3XuTdQwnY/SquvVcJyJPJZ/dJyLr647B20VkG7DtSGMtIsWkT88m+/ZPItKYfHaZiOwWkXeJyP5kn/5otj4vGSGEVPwBO4HfiHzWBDwJvBl4BdAHrEs+6wZ+N2nTCvw78M26ZX8EbAc2Au3AY8m6foPandO/AJ+vax+AO4Eu4JSk7R8nn70ZuCt53QzsAv4oWc+Lk36dE9mHm4HPJMutAO4B/nQO+3cZsHvGuj4AlIArqV0QGoELgUuTvmwAtgLvTNq3Aj3Au4CG5P0ldev64lH09c+Ax4H1yRjdmYxZztjns5IxWpO83wBsTF6/G3g4aSPA+UB33TG4LVl/45HGGvgEcGvSvhX4FvB3deNXBj4E5IHXAWNA51Kf82q8lroDi7ajNWcfAQbq/v6k7vNLgH7gGeDqWdZzAXCo7v2PgL+ue/8x4Lt1718PbKl7H4DNde//J3BH8vrN/Jezvwn46YxtfwZ4v9GnlcAk0Fhnuxq480j7R9zZf3KE8XwncHPdth6ItPsAdc5+pL4CPwT+rO6z1xB39jOA/dS+WPMzPnsCuCLSpwBcXvc+OtbUvihGSb5Eks9eCuyoG7/x+v4lfbp0qc/5mX9p+81+ZYj8Zg8h/CK5bVwBfO2wXUSaqH2zbwY6E3OriGRDCJXkfW/dqsaN9y0zNrer7vUzwBqjS6cCl4jIQJ0tB/xrpG0e6BGRw7ZM/XZi+zcL9X1ERM4EPg5cRO1OIQfcl3y8HnhqDuucS1/XoMfHJISwXUTeSe0L5VwR+T7wlyGEvXPoU/02Zhvr5dT29766/gqQrWt7MEz/3T+GPuZLjv9mTxCRtwNFYC/wnrqP3kXtVvCSEEIb8MrDi8xjc+vrXp+SbHMmu4AfhxA66v5aQghvi7SdBJbVtW0LIZx7uMEs+xeb9jjT/mlqt9ebknG4jv8ag13A6XNcz5H62oMenyghhH8LIbycmsMG4CN129k426Iz+hQb6z5qX9jn1n3WHkI44Zz5SLiz89xV62+APwSuAd4jIhckH7dSO9gDItJF7dZuvrw7Ef7WA38BfNVo823gTBG5RkTyyd/FIvKCmQ1DCD3AD4CPiUibiGREZKOI/Noc9q8X6BaR9iP0uRUYAkZE5Gyg/kvn28BqEXlnIma1isgldevfcFiEPFJfqd11/LmIrBORTuB9sQ6JyFkicrmIFIEJasepmnz8z8D/EZFNUuM8
EemOrCo61iGEKvBZ4BMisiLZ7loR+c0jjNcJR9qc/Vsy/Tn7zVIL1vgi8JEQwoMhhG3Urlr/mpxEn6Qm4vQB/wl87zj04xZqt8BbgP8AbpjZIIQwTO336u9Tuxrvo3bVKkbW+T+AAjWB8BBwEzUHnHX/QgiPA18Gnk6UdusnBcD/Av4AGKZ28j/3BZX09dXU9Il91BTuX08+/vfk/0ERuX+2viaffRb4PvAgcD/wjUh/SMbiw9SOzT5qP1H+Kvns49S+OH5A7UvqBmrHUTGHsX4vNRH2P5OnE7dTu9s7qZBEUHAWCREJ1G6Fty91X5x0kbYru+OkFnd2x0kJfhvvOCnBr+yOkxLmFVSTxD7/PbUAg38OIXx4tvbtrU1hZXfHXFduGO27EDHaRsLNI8Tubqz1zm37MaJtj+YGaz5P+Gfpgzlmxp3fVKlkLj8+VdG2SXuOSSajt3U04xgbrvamgrI1FPPKVqrYazgwXFW2ELkmVgz36Zq0Y4CKYUrZMpH9zZjnmNlUnQq7RwIHJ4LZ+pidPZm48A/UHrnsBn4pIreGEB6LLbOyu4P/97//eJotk8mabcU4GYLokwkgm9W7kcs32G2N41aN/JTJGn3LGkciV9AnGEBA9zefizw5C/okk4hXW0MWqvYJWUWvt7HB7kM+q52iXNbOumtfr7IBPLJzQNke3nHIbNva0qRs1hcAQDard3iqbJ8Lmy8+TdlesGGZsu0bsL+E/unHk8o2iX0ujWT1et/0hBXzBKdXn1W2xoLtfk2GuWi7CcXs9HPkNbdM2A2Z3238S4DtIYSnQwhTwFeAK+axPsdxFpD5OPtapscX705s0xCRt4rIvSJy7+Dw2Dw25zjOfFhwgS6EcH0I4aIQwkXtrfrWzXGcxWE+At0epk9YWJfYZiXMUBrE+D0GIMZvY+t3fMxercbEPGN7on/XAgTjd3SlqodMKvbylt5Vqdi/FS1xKhMZG+u3ddH6oQdkDZHioW369yPA4Ki+8+o/NKps377LlmVWLFuubK2tbZF+6X2rVOzf4c88s1PZJiPCn1ys583kjLHtbjUX5+9+r1nZfv6UfXz/4867la2hNGy2zeaN4xsT3SyBzm6q2s6mcc7nyv5LYJOInCa1jCW/T22Cv+M4JyDHfGUPIZRF5B3UJi1kgc+FEB49bj1zHOe4Mq/n7CGE7wDfOU59cRxnAfEIOsdJCe7sjpMSFjUHXUAHioWIAhvQam0mUhREjOjAULXV2ooRUVbI2xFwpvhviPzVqr0PxqaQvP39WsjpfcsXbDV+eGxc2f79tgfNts1NOvrrBz97yGy7u1dHu5256Uxle+1vvtZc/umnn1a2WAjsvn37lK2v74DZ9uc//5mylcu2Qt6Y0Wr4X7/tamVrb7fl+DEdQMfFZ9qPjJt+rPXoyoS9D6GooxZjavxC4Vd2x0kJ7uyOkxLc2R0nJbizO05KWFSBToDMjBjSij311uxYNmu3tUJCSyVbNDMFI4mE1lrfhVa0bWTuvBhqXkuLmeCUm76rQy937Ok325arWpx64DEtjoE9Ng2GWATQ0aZDRVevXKFslhAHkM/rMN6uri6z7YMPblG2HTvs9RYbtMjYGBnz7/34AWV7239/vbJ1dNhp3y2x1RpvgKo1XTkmuhmnWCxJlGWPRH+rtrMlnvIru+OkBHd2x0kJ7uyOkxLc2R0nJbizO05KWFw1XoTcjCR7pUim0hDMzJB2WyOuNZa500qQEIxQ1dp6tc3KrJqNJNV4cqcuzvqNH9phrSOjOk5z1IrdBIIRh9vabId0WgkwpqZ0plOAipGEwxqvyUl7+f37tZp++ul2IdVcTo/Z6IhOlAGYT0sykWSN1hOFf7nlJ8r27rdohR6gYGSiDTEp3LDH1HBLz4+Wz7WU+0hb+zmBjV/ZHScluLM7TkpwZ3eclODO7jgpYb7ln3YCw0AFKIcQLjrCAqrqRzYylzwYkkQ5on40GFpcsckWrEZGdcWMvQfsqiVN
RljpDbfepWy9B3QlFMBM9VmOVDKxhLRCQ+TwGOMQE4YsgS0fESStLL05o+3klF11ZMczO5RtzRpVSgCAYlGHDRcNcQygu7tb2c4//1x7vUZ/f/Sze5Rt5zN2IuQvffIvlW0yIvZOGkJnJFExVgLiSFJiqsbm4mLe3OuGHQ81/tdDCH3HYT2O4ywgfhvvOClhvs4egB+IyH0i8larwbTyT0OR56iO4yw4872Nf3kIYY+IrABuE5HHQwjTIhhCCNcD1wOcdfq6oylM7DjOcWReV/YQwp7k/37gZmqVXR3HOQGZT332ZiATQhhOXr8G+NBsy4QQKM2o953L2QpsIafV6XzR7u7Wp3uUrRJRKbc/q2uLf//ndiEbK1lGY1EnUoiFU5aNDLdWMgmw82dUI8p91ahBl8/ZTzWsZArlSEZfCdref0g/qchFwoObG7XCPjxi1z771Ze9XNnO2Hiq2Xbo0EFly8aeKBjKec540tE/NGIu/+i2Xcr2onM3mW1bN7xA2Ub2bDXblq2kGJH73LKR0CUXUe7LMw7FbLfO87mNXwncnAxuDvi3EML35rE+x3EWkPnUensaOP849sVxnAXEH705TkpwZ3eclLCo89kzGaG5cbrA1XNw0Gz7nw/r0MvWZjsz6w/veUzZRsbtOdfWPPdipMyShRV+GsMShmLpR63w4FgopJnNNpLV1EqMauh7AEyV9ZgNDehyRqtWrjKXP2vjBmVrarGzuDa36vJL64v23Pf7enRegP5+LcoCtLW3m/aZHOizQ6Tf++EvKNsH//IPzLYXvl6Xlbr/UR1ODVAa1OG5ZfP80KIbxMW83FE8zPYru+OkBHd2x0kJ7uyOkxLc2R0nJbizO05KWFQ1/uDACJ//1s+m2Q4NjZltt+3er2z5jK1eFvJ6N4qR7KPBkKet+l4AwQhbzOR1HyLl6kwieRAQcyX2d7H1RKFU0qG5YNd1W9HdFumE3l5Xh25bLtvJKyatTMEj9tOWnr1anV6xfJnZdmJiXNn27NUKPcDeXn3elMu6X8WiHV68Y49e/u4HnjTbXnrhi5RtfMJ+CoRxihUjT0Wyhj0XOW9mJsDwWm+O47izO05acGd3nJTgzu44KWFRBbqxiRJbnpwuzMTKNLU0GvPGI3GeR6GPmXPBrZBSgIyxZiuEVSLfmXa4q91bax52TMxb3q1DQpuaddkjsOfPVyp2ya2ikem3ag6O3bFCXucmOHBgn922UYfLxvZ3bFwLdBecf57Z9uHHtJg2Pj6kbHmjrwAtzfq8u+k7dgjsBWefrmwvuvbdZtsnrv9rZStF0svmMkZm5cjgzAyjnS161q/sjpMS3NkdJyW4sztOSnBnd5yUcESBTkQ+B/w2sD+E8MLE1gV8FdgA7ASuCiHYE4Snr4uGGdFulZg6dhRCWMUQ3UIkWaMlWBUiiSxNIco0xSRCbW9qtCO3ujt1pFo+UhqraiSGzGfsfbASHYbIOI5PGPXgjePQ399vLt9uzCUvTdlioCVIFgpaHANYd4oWwgaNefYABSNi0BKtZpYhe65fFd16IpIb4d0fvkHZ/u2T7zHbTpT0eiPVrszkktmI9JbPTB/H+Qp0XwA2z7C9D7gjhLAJuCN57zjOCcwRnT0p+jDzq/wK4Mbk9Y3Alce5X47jHGeO9Tn7yhDC4bxA+6illTZJykK9FaBYiOQ2dxxnwZm3QBdqkSPRnwohhOtDCBeFEC7KRQIZHMdZeI7V2XtFZDVA8l/PC3Qc54TiWG/jbwWuBT6c/L9lrgvqSkn2TYEZlmplVcX+xipHJvaWyoZyb002BvJGiaEG46dI3phPD9BthLVmYvtglLsyoiYB+0nDRMmeY14xQjLHxuwcAlbIraWmZyJZUa3js/6UU8y2driqvcOdnV3KdqjfVuMnJvQ45I0SY7HMvWT0U4JsRLkvG2P7y0e2mW3PPvtCZRt7/B6zrVVeKx+5LM8U+ec1n11EvgzcDZwlIrtF5C3UnPzVIrIN+I3k
veM4JzBHvLKHEHRy7BqvOs59cRxnAfEIOsdJCe7sjpMSFnU+O1hhkrFQ07l/D1khtzERqa1Ni1CFSFhq0Ugu2dSkQzolIuBYws7YqC2OZQ2Rr9Godw52ffX9++wHIoVGHT7a0d5htrXEndZWHcYbDBELoMHIQTA6rOeiAwwO7Z7TtgCmpnS4au+BPrPtoUE9d71gCHRILEzbMtr7mzdUs49ef7PZ9m2/8wplWzf2M6MlFPP6mOUiVcdmHgqfz+44jju746QFd3bHSQnu7I6TEtzZHSclLLIaH6jMTKYQi+8zBNCWBls1b2jSKq6VpAJAjEQXhUgpoMEhXbook7USUjSZy08ZySB6euyyRatWrVK2A0NaWQZb/W9utvtgZdPtP2TnGeno7FQ2K/lFLlKGa7B/WNlW/IrZlMHt2nYworA//vgjyjY0pLcFtvJuPSypVm2F3U50Ybc1M+/qeHAAvvmTB5Xt42/8PbPt3tu/rmzZFnsSmcj07UU2D/iV3XFSgzu746QEd3bHSQnu7I6TEpYgXHa6grDMyKoK0NBghKVG1pk15oIPDdt1wYtFHYJamjKyqgKd7VqwKpV026oRvgowMjKijRFB8lAkY6tFzphnn4+IjBNjen53m5EFFuBgnxbIrPn3y1esMJcfNkJjW9bo0E+AF27uVravvO8+s21jQR+zZd16jjvA0PAuZRNDUIxOZzeFXbtxkxHOPI59LvUN6+PwmdseM9tuNk70iZl1nhKyM8LPvT674zju7I6TFtzZHScluLM7TkqYSw66z4nIfhF5pM72ARHZIyJbkr/XLWw3HceZL3NR478AfAr4lxn2T4QQPno0Gyvk86xfYyu5c2H/fjtBgxWu2tFhK84jw1oh7+3tNduuWKlrXwwbIayTU/vM5RubtBKdjSTVKBsZY48mi2tjo07KATAwYPS376DZ1iJvPL24/14dvgpw3ublynbGr60x20pB78N/e/9Gs23OSCLy6O12dtlduy3lXW9LIskryqWy0dZ+DjRS0olIrCclABUjAcft2+wacu2bdG271zXY+zs5Nf0JyLySV0TKPzmOc5Ixn9/s7xCRh5LbfP1A2nGcE4pjdfZPAxuBC4Ae4GOxhiLyVhG5V0TunZy0b1scx1l4jsnZQwi9IYRKCKEKfBZ4ySxtn6v1VoxEeTmOs/AcU7isiKyuq+L6BsBWbGYQQlACSCyD6mRJ3wWUSroUEUDvsCXc2VJFnxES2tFh/wrp6NBZWEeN7LAdHbY4VjTm349G5mFPGXc9MbFl0pgnPz5ui4RWkPHYuB3SaWWt7crpfVi2zA5VzeRblK1wjxa8ANpHtUCWv9wex6lx3faSq2zhb7BH79tDd+ixCRlboLPCZSP6HGXjdKwa8/9rKzFsVft8Hj/1pco2WXjSXu+OLdPehlkmtB/R2ZPyT5cBy0RkN/B+4DIRuYDa+bgT+NMjrcdxnKXlWMs/3bAAfXEcZwHxCDrHSQnu7I6TEtzZHSclLGryipoaP12BLEXCC3t7tcLe0mSrtdWKVlYrhg2g3ahzls/bmTsnJqzEDzrZxsSkXc/skJHFtVjUSTkAgvGkoWxlLwWmJrXiPGmEeQJMGfZVq+2Q5WYj7HjtWh0yPJmxFd/2x/STijM+M2C27RrR+/a1z+uEFgBrX6aPT+8W+6nG6H6thmdzur/ViMSeER1uGzuXrMQeUS3c+KCxYJ8L3/vh3cq2q9k+x65eOz0kuxxJngF+ZXec1ODO7jgpwZ3dcVKCO7vjpITFzS4rIDMywfbut+fptjbr0MvW1lazbYthzxllgMCeu56NiIRWdtiqkb6zsdEWWiplLez09tr7Wyzque+TU7boVjAExY0bTzPbBnQfOtvsuf6Fgg6NHStpkTK/xxbHXnuPFhk3VmxRdbRF9+vS99vrLf+B3t+ta+1j1ljWc+rPOluLslufeMJcPhjHNxaCao1tCDHhz7CLvd6ikQ334THtDwCvmRyd9r4S2T74
ld1xUoM7u+OkBHd2x0kJ7uyOkxLc2R0nJSyqGi8iZLPTldXOSBbYcSNUdXDQrt82ZdVqiymohgJaqcRUb62QDw3pPhzo02GxALmsHl5LzQdo69Chosu67Dp42YxWXJuabdU7ayRjGBq0Q1gP9uknBc1GoorG/lFlA1g2oROR7MzrEFqAtqoem1f26nBdgN5P6eO77cV2irPH2/WxLBj5JNra7LHtH9DHN5ZdNp/VTy9KVftcssJwY4kurO01GMo/wJcPTH/60F/uMduBX9kdJzW4sztOSnBnd5yU4M7uOClhLgkn11Mr/bSS2qzc60MIfy8iXcBXgQ3Ukk5eFUKwlar/Whf5GeGywRCQwA5h7YuULZqY0CLQqlWrI2212NPTY5d/KhlzwRusbLiRCMWmZt32nHPPMNsGQ3/JZOyxmTTEyx1PPWW2tcS8WFkpa+57/6BRPiqSmfUdF2vR7JrddojzGYNanBrAFvNWTBoiY589bzu0a9GssUkLf+tX2udHW78uftQfuSaOWmJrTBg2xiyi1ZrZbK2SXwCHpqbby9X5hcuWgXeFEM4BLgXeLiLnAO8D7gghbALuSN47jnOCMpdabz0hhPuT18PAVmAtcAVwY9LsRuDKheqk4zjz56h+s4vIBuDFwC+AlXWFIvZRu823lvHyT45zAjBnZxeRFuDrwDtDCNN+yIXavEDzF4iXf3KcE4M5RdCJSJ6ao38phPCNxNx7uAyUiKwG7OLpdVTKFQZmCD7ZSHTS8LCO0opN1c3n9XzyrU/Y5XIsIazJEHAAOjr1POjVa9bqdq32fHarXNWhfjt6raVFz1e2yjwBHOjTQ20lPwSYnNI7HKshPlnSba2IvzWr7NJLVh31rxZsUXV0Uo/NKuxxvHRYC51bzrYj4NqMikoFowzX6QV9bAF+a68uD/blTjvK8+YW3a/miKhqIZH0lFUz0WhknnxmptAZL/90xJ5JLXbvBmBrCOHjdR/dClybvL4WuOVI63IcZ+mYy5X9V4FrgIdF5HAVueuADwNfE5G3AM8AVy1MFx3HOR7MpdbbXUSfJPOq49sdx3EWCo+gc5yU4M7uOClhUeezT06V2PnM3mk2iWTYtMIOJ61569jzf9euW2e2zed029YWW23t7NaK7+5nn1W2YtYuW1Su6JDQiXE7JNR6IhDLejsxriXn2JOKgvGkwkh6G+3D6hXLlC1fsDP3drTp0Ng1K/XyAHv36ScK42X7XPj5an1NasoYsjuQadRtM0a5qpWdnebyY1N6vS+IxIfc36hV/n1mSygauQ2momWljAy3UeU+Hh6r1jvnlo7jnNS4sztOSnBnd5yU4M7uOClhccs/EQgzkuxVY4/wrTDNtevNpl3tOtliZ0SAyYgREhqpg75/n5ZbLBGrr0+HWAKIMZc8osmwa7feVjYy7zxjiD0NRvmopBPK1Nhgt7Xmvq9aoWu5txtCHEC+qIW7WGhuJmcIaZHcBvse0evNPHqe2TYYxzf/a/cq29RKu5zSxFVvULbzvn6r2fbOZi1+7o6Il5myzhUQrNhtIBjJKY1DA+iyYWaZqcOfRT9xHOd5hTu746QEd3bHSQnu7I6TEtzZHSclLLIaLypjaiUiT1vhgR2G6g7QZSQXGBsfMdse7NcJcK0kFQCTkzo8d3xMh7u2ttrhtjuf3aVsjQ12ggYrk2zZyPYKkM9rxbexyR6bsqECNzXafVi1crmyWRlyGyOlpnI54+lBREZev04/WWltt/vVoBO+8uTtdohyK/pYhG/pkN2n7o2kSLtwuzKdZ4RYA7xpUCdY6em2z6UeQyXPZOynLRaVSNbasfHpmYZj5cXAr+yOkxrc2R0nJbizO05KcGd3nJQwn/JPHwD+BDhc1Pu6EMJ3Zl9boKrmeNvihxiC1fandpptG889U9kscQ2gp0eHpeaydohjyZhfPTSkhb+RUXtbrUbG2BgTxpzp2PItzTpkt63NFs2Gh8eVbdUKM8U/K4y55/mcnrOd
y9nj1WwJd5HLSWlK729To72/bZ26bXWZzisAMNmvyzrlRvR6wzY7O+2Wsc267ZmPm21/Zes2ZXvjHvtc+Mc1Ouw4OhXdEtliCQvU3Pe4QDcXNf5w+af7RaQVuE9Ebks++0QI4aNzWIfjOEvMXBJO9gA9yethETlc/slxnJOI+ZR/AniHiDwkIp8TEXOaWX35p4qRpslxnMVhPuWfPg1sBC6gduX/mLVcffmn2JRNx3EWnjk5u1X+KYTQG0KohNqk3M8CL1m4bjqOM1/mosab5Z8O13lL3r4BeGQO6yJXmL7J2K29lU8iFgrY23tA2fqNsFiAQlGHZO7bb5epayjoJA+NRqhpzkgmAVAyQoGbjO0DtLRoJTumxncZ4b2xTLTrjNp0IRJ6SdDf/Q2NWvmP6b0jIzp8tKOrK7Itw1S17/xOuUj3obTCruX30Nd3K9vw/RuVrQWt2gNM7dH9Hdj3m2bbJ07VTyXGW+wQ58ygEb6dtRX2nFjH0m5bUY4ST14xn/JPV4vIBdQO207gT+ewLsdxloj5lH86wjN1x3FOJDyCznFSgju746SERZ3PXq0GJienl9fJRovXawWnUrbFvP0H9ITnpiY7g2pTkxbIJiPlfSqGiFQw+muV6wFoMwS2jnY7TLPByEpqzg8HskbG2Dz2/k6MacEol7dFnFJJj4M15u0d9vx9qwxXtWyXabKOe7VqH98wpce3OWtnuO14iRbu+tB9qPKMufyK8inKtvMhHUILMPgiHV5cPEsLhAD5b95h9ME+bzLG8c3mI1l6Z2SizczyeNuv7I6TEtzZHScluLM7TkpwZ3eclODO7jgpYdFrvVWr09XhWDH5mSojQKyMlakCR+q3jYxOGFZ7xVbts2XdOpyyXLIV55ZWrRg3NuhsrWAnpLDqtAFk81o1X3d+j9ESMkZI5uBOO1S0Oq6fCFSNBB5FI4y41i99zKaMJBUAQwMDytZkjQGAEd67wsiEC/DM7j26X5t0ootsg33qh6zu18HKXWbb7due1sY9kay1xrlUiCQByRf0E6M3vvYSs+3Zp0/P0vueD/2DvX38yu44qcGd3XFSgju746QEd3bHSQmLKtBlMhkaG6aLMLEssNmc/h6aighhGPZSxZ5X3NykRXXBr00AAAxKSURBVKBcZC54d5eeN54zwhG7OuySP1bYo2Tt71dLUKxGSmOR1YLVJPb8/fWn6JDd5i4tYgFUy4bIV9KC4n5DlwJobdOhwK2GSAnQbMzfL03Zx9c6ZlNT9nnTaRyLPT06nHp82A7NfeDJHytbxQi3BShP6eOTLdvn0vLlulzVuZt0CSyAa95wmbJ1tNvjWJyR/Lex0Rb9wK/sjpMa3NkdJyW4sztOSnBnd5yUMJeEkw3AT4Bi0v6mEML7ReQ04CtAN3AfcE0IIRI+lGwsm2VZ53TBaHePFdEG5YoWRWLfTJa4lYm0LhjRX4WCPQzWPPdmI3Irl9clkmp2LeaVJ22xZ8rYlkQi6JoKOhqr0rfObFtep4W7TD4yx9zcDV2PvmHdE+by9L9UmaaMOfIATc1acJJmW5A81K8Ftr6+g2bb+x7QeU+3btNlmmJZM8vG/P3Vq+xa8OtW6zJaElnxu956hbIVInPUc0Y9+GpEcN7bP71tyW4GzO3KPglcHkI4n1qO+M0icinwEWrln84ADgFvmcO6HMdZIo7o7KHG4Ty4+eQvAJcDNyX2G4ErF6SHjuMcF+ZaJCKbpJHeD9wGPAUMhBAO3zTsJlL/rb78U2m2ewzHcRaUOTl7UvnlAmAdtcovZ891A/Xln/KR3yiO4yw8R6XGhxAGgDuBlwIdIs+VrlgH2GFZjuOcEMxFjV8OlEIIAyLSCLyamjh3J/BGaor8tcAtR1pXd0cLb77yFdNsH/zHb9gdM0oRZSKhplYJqXKkrNTQ0JDuV7ettk4Z4ZsT4+PKVszbIYqT4/pnSyWWQdWafx+ZwH/GubqUUNdaWwUul411RMpoWVWhrDHPN9oK
e65Z26tle+77lOi2T27bbrZ9fKtW/+/f8rDZFiNrrXVH2daqw3UBfuW8M5TtVy98gdn2pS8+S9liId3ByJkQIsehXDKeLmXsJxVf+Nn0RygHjSpTh5nLffVq4EYRyVK7E/haCOHbIvIY8BUR+RvgAWr14BzHOUGZS/mnh6jVZJ9pfxqv3Oo4Jw0eQec4KcGd3XFSwuI+CwtQKc+tnnSprMWtQkQIM5NLRsSPCWP+/NDQsNnWKhV18JBOSJiPlGnKZIykmUbiwVq/dNhwNTJ9//Gf6nF4xR/aolnVDG2IJPk0dmNkSI9XCPvM5Xds/55evscMv6BU0mG4d99zn9m2WNBxvIVInfuXX6TFtI0bdILN9ohAd/ml5yibkXMTgDGrbFgk0ak15FZS1Vpb3TgfqeU+M0lnJlpOza/sjpMa3NkdJyW4sztOSnBnd5yU4M7uOClhUdX4XC7D8hnJKza//Dyz7fd/+qCyhUhiAMuezdq7Zqn8pZKdqbRc0evYf1Ang2gs2E8Jmpp0ZlYrtBfsMk3ZyMShqTHdds/DdumklWePKttAn72/VSO0dvs9WrV+wip7BAz068cHg/12CGw2o49ZcyQT7enrdMKQ119+odn23LNOU7b2Nq3cl9VToRqjhsJuPVWp2ed+rbSeGEVEfopFfRzueNzuw8DE9HOvEiI10vAru+OkBnd2x0kJ7uyOkxLc2R0nJSyqQCciFGeIWWeeatcKv/OercpWipT8yRolmWLhsjkjtNWatw4go3ruuhUae6Bfh9AC5If15GJz3jqwZtUKvf2I1jI5qdfx7OORefKtOhR418N2uar77t6lbCODvco2Pm4fh2XL9HpPP02Xn6qh9+Ed17zGbNlqCJ0dbXadeyuPweiIDkUWsc8PsQY9InpZ5bliGYHFCp2OnKOZoNe7ayCSFyBM96eqC3SO47izO05KcGd3nJTgzu44KeGIzi4iDSJyj4g8KCKPisgHE/sXRGSHiGxJ/i5Y+O46jnOszEWNP1z+aURE8sBdIvLd5LN3hxBummXZaYQQmJhR6+zcjXZyg1cZSQS+efu9ZlsrWtVU6IGqES4bC1usGPXmska21dFxu15dZ1GHfy7r7DLbWqGXVSvdK1Ay6n49+ogO4wX4wff3Kls+f8BsWw16zJqadeKIszatN5d/5UW6nMDLLtQZWAEmjScrhUiIc8Y4lrFwV0tMnzTqzRWKdn2+7FGEwJaN45CLnXeGch/b1siklenCVuPzueknv/k04XDfop8khFq+W6v8k+M4JxHHVP4phPCL5KO/FZGHROQTImJ+9dSXfxoa0amIHMdZHI6p/JOIvBD4K2ploC4GuoD3RpZ9rvxTW4s9M8txnIXnWMs/bQ4h9CQVXieBz+M55B3nhOaYyz+JyOoQQo/UFIErgUeOpQOxckiFvP5VUMjZ88at0MdK2V6vpeDEvvGCEXpYMYSWSNQjYxNaGCqX+822DQ16znVMGHlqhw5rzUcy7xYbdbhqLKTz0vM2Ktvq5Z3K9oZXX2QuXzbG3BKmAPJZ3V9rbAEQQ9zK2aduxRLNcoYYFxE/gxHGS2S8jkZUzeWMcmbYVY1/uE2fC08csAXFjvbp45iZj0BHvPzTD5MvAgG2AH82h3U5jrNEzKf80+UL0iPHcRYEj6BznJTgzu44KcGd3XFSwqImrwghKLU0YIcXnnmKzii6stvOPrr3gE4eESmpRsZSViMKZiwT7EwKRVsJHxwcUjZL4Qc7pDMXyWpabNSJG844xU4C8spLXqhsY2M6KQfAr1+i66Q1NOh9G7dqnAFWObJsTMk27NXIpceqiWbW98NOEmEd36qRIAJArEMe8xJLeI88mimX9IpbGu0dbjLGPJe31Xj1RGAWNd6v7I6TEtzZHScluLM7TkpwZ3eclLCoAh0IMkNQqFZtQeGcTeuU7ZS1OgMrwK79OgQ1m7F3LZgCytxn7JZKRibayOJWKHAsrLWrU4elNjXqsEmAq3/7ZXr5
dl2mCWDjqauULZJYlZIxR3yqpG2ZyDWiYoheEhXSjLDliCBpDXBsH6zjWzWOQyai4AZjxdlISHfGmH9fjcVOBx0au32/3YcHdmkxriFSYkwi57mFX9kdJyW4sztOSnBnd5yU4M7uOCnBnd1xUsIi13qD7AwF01bHYXRch2T+zmsuNtvu7tVq/J59dgbVjKFehojaamXqtBIWxBIpLO/WIb/lSIKG392sE/2cvWGN2bazVYfLxpImjI0Zoa3RUGJtC8YYhMjjBys01qq9lqxYbz8SSjxlKNmxJA1Wdlfr+ISqXd/PCpGuRpR766iHSHhwg9GHwQn7vOkd1cp7V2cs8+50+2zZZf3K7jgpwZ3dcVKCO7vjpAR3dsdJCRITyBZkYyIHgGeSt8uAvkXb+OLh+3Xy8Xzat1NDCFoZZpGdfdqGRe4NIdg5iU9ifL9OPp7P+1aP38Y7TkpwZ3eclLCUzn79Em57IfH9Ovl4Pu/bcyzZb3bHcRYXv413nJTgzu44KWHRnV1ENovIEyKyXUTet9jbP56IyOdEZL+IPFJn6xKR20RkW/Jf55s6wRGR9SJyp4g8JiKPishfJPaTet9EpEFE7hGRB5P9+mBiP01EfpGck18VETtJ+0nOojp7Ugn2H4DXAucAV4vIOYvZh+PMF4DNM2zvA+4IIWwC7kjen2yUgXeFEM4BLgXenhynk33fJoHLQwjnAxcAm0XkUuAjwCdCCGcAh4C3LGEfF4zFvrK/BNgeQng6hDAFfAW4YpH7cNwIIfwEmDm/9grgxuT1jdRq159UhBB6Qgj3J6+Hga3AWk7yfQs1RpK3+eQvAJcDNyX2k26/5spiO/taYFfd+92J7fnEyhBCT/J6H7ByKTszX0RkA7WS3b/gebBvIpIVkS3AfuA24ClgIITnJs0/H89JwAW6BSXUnmuetM82RaQF+DrwzhDCtMJ1J+u+hRAqIYQLgHXU7jTPXuIuLRqL7ex7gPV179cltucTvSKyGiD5v3+J+3NMiEiemqN/KYTwjcT8vNg3gBDCAHAn8FKgQ0QOp3x5Pp6TwOI7+y+BTYn6WQB+H7h1kfuw0NwKXJu8vha4ZQn7ckxILbfRDcDWEMLH6z46qfdNRJaLSEfyuhF4NTU94k7gjUmzk26/5sqiR9CJyOuATwJZ4HMhhL9d1A4cR0Tky8Bl1KZI9gLvB74JfA04hdp03qtCCDpJ3gmMiLwc+CnwMHA4ad511H63n7T7JiLnURPgstQudF8LIXxIRE6nJhZ3AQ8AfxhCmFy6ni4MHi7rOCnBBTrHSQnu7I6TEtzZHScluLM7TkpwZ3eclODO7jgpwZ3dcVLC/wcdLVLqxA8mFAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "resize = T.Compose([T.ToPILImage(),\n",
    "                    T.Resize(40, interpolation=Image.CUBIC),\n",
    "                    T.ToTensor()])\n",
    "\n",
    "def get_screen():\n",
    "    # Returned screen requested by gym is 400x600x3, but is sometimes larger\n",
    "    # such as 800x1200x3. Transpose it into torch order (CHW).\n",
    "    #env.render(mode='human')\n",
    "    screen = env._get_observation().transpose((2, 0, 1))\n",
    "    # Convert to float, rescale, convert to torch tensor\n",
    "    # (this doesn't require a copy)\n",
    "    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255\n",
    "    screen = torch.from_numpy(screen)\n",
    "    # Resize, and add a batch dimension (BCHW)\n",
    "    return resize(screen).unsqueeze(0).to(device)\n",
    "\n",
    "\n",
    "env.reset()\n",
    "\n",
    "# number of agents\n",
    "num_agents = 1\n",
    "print('Number of agents:', num_agents)\n",
    "\n",
    "init_screen = get_screen()\n",
    "_, _, screen_height, screen_width = init_screen.shape\n",
    "\n",
    "# size of each action\n",
    "action_size = env.action_space.shape[0]\n",
    "print('Size of each action:', action_size)\n",
    "\n",
    "plt.figure()\n",
    "plt.imshow(init_screen.cpu().squeeze(0).permute(1, 2, 0).numpy(),\n",
    "           interpolation='none')\n",
    "plt.title('Example extracted screen')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "writer = SummaryWriter()\n",
    "i_episode = 0\n",
    "ten_rewards = 0\n",
    "def collect_trajectories(envs, policy, tmax=200, nrand=5):\n",
    "\n",
    "    global i_episode \n",
    "    global ten_rewards\n",
    "    global writer\n",
    "    \n",
    "    #initialize returning lists and start the game!\n",
    "    state_list=[]\n",
    "    reward_list=[]\n",
    "    prob_list=[]\n",
    "    action_list=[]\n",
    "    value_list=[]\n",
    "    done_list=[]\n",
    "\n",
    "    state = envs.reset()\n",
    "\n",
    "    # perform nrand random steps\n",
    "    for _ in range(nrand):\n",
    "        action = np.random.randn(action_size)\n",
    "        action = np.clip(action, -1.0, 1.0)\n",
    "        _, reward, done, _  = envs.step(action)\n",
    "        reward = torch.tensor([reward], device=device)\n",
    "        \n",
    "\n",
    "    for t in range(tmax):\n",
    "        states = get_screen()\n",
    "        action_est, values = policy(states)\n",
    "        sigma = nn.Parameter(torch.zeros(action_size))\n",
    "        dist = torch.distributions.Normal(action_est, F.softplus(sigma).to(device))\n",
    "        actions = dist.sample()\n",
    "        log_probs = dist.log_prob(actions)\n",
    "        log_probs = torch.sum(log_probs, dim=-1).detach()\n",
    "        values = values.detach()\n",
    "        actions = actions.detach()\n",
    "        \n",
    "        env_actions = actions.cpu().numpy()\n",
    "        _, reward, done, _  = envs.step(env_actions[0])\n",
    "        rewards = torch.tensor([reward], device=device)\n",
    "        dones = torch.tensor([done], device=device)\n",
    "\n",
    "        state_list.append(states.unsqueeze(0))\n",
    "        prob_list.append(log_probs.unsqueeze(0))\n",
    "        action_list.append(actions.unsqueeze(0))\n",
    "        reward_list.append(rewards.unsqueeze(0))\n",
    "        value_list.append(values.unsqueeze(0))\n",
    "        done_list.append(dones)\n",
    "\n",
    "        if np.any(dones.cpu().numpy()):\n",
    "            ten_rewards += reward\n",
    "            i_episode += 1\n",
    "            state = envs.reset()\n",
    "            if i_episode%10 == 0:\n",
    "                writer.add_scalar('ten episodes average rewards', ten_rewards/10.0, i_episode)\n",
    "                ten_rewards = 0\n",
    "\n",
    "    state_list = torch.cat(state_list, dim=0)\n",
    "    prob_list = torch.cat(prob_list, dim=0)\n",
    "    action_list = torch.cat(action_list, dim=0)\n",
    "    reward_list = torch.cat(reward_list, dim=0)\n",
    "    value_list = torch.cat(value_list, dim=0)\n",
    "    done_list = torch.cat(done_list, dim=0)\n",
    "    return prob_list, state_list, action_list, reward_list, value_list, done_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_returns(rewards, values, dones):\n",
    "    n_step = len(rewards)\n",
    "    n_agent = len(rewards[0])\n",
    "\n",
    "    # Create empty buffer\n",
    "    GAE = torch.zeros(n_step,n_agent).float().to(device)\n",
    "    returns = torch.zeros(n_step,n_agent).float().to(device)\n",
    "\n",
    "    # Set start values\n",
    "    GAE_current = torch.zeros(n_agent).float().to(device)\n",
    "\n",
    "    TAU = 0.95\n",
    "    discount = 0.99\n",
    "    values_next = values[-1].detach()\n",
    "    returns_current = values[-1].detach()\n",
    "    for irow in reversed(range(n_step)):\n",
    "        values_current = values[irow]\n",
    "        rewards_current = rewards[irow]\n",
    "        gamma = discount * (1. - dones[irow].float())\n",
    "\n",
    "        # Calculate TD Error\n",
    "        td_error = rewards_current + gamma * values_next - values_current\n",
    "        # Update GAE, returns\n",
    "        GAE_current = td_error + gamma * TAU * GAE_current\n",
    "        returns_current = rewards_current + gamma * returns_current\n",
    "        # Set GAE, returns to buffer\n",
    "        GAE[irow] = GAE_current\n",
    "        returns[irow] = returns_current\n",
    "\n",
    "        values_next = values_current\n",
    "\n",
    "    return GAE, returns\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_policy(envs, policy, tmax=1000):\n",
    "    reward_list=[]\n",
    "    state = envs.reset()\n",
    "    for t in range(tmax):\n",
    "        states = get_screen()\n",
    "        action_est, values = policy(states)\n",
    "        sigma = nn.Parameter(torch.zeros(action_size))\n",
    "        dist = torch.distributions.Normal(action_est, F.softplus(sigma).to(device))\n",
    "        actions = dist.sample()\n",
    "        _, reward, done, _  = envs.step(actions[0])\n",
    "        dones = done\n",
    "        reward_list.append(np.mean(reward))\n",
    "\n",
    "        # stop if any of the trajectories is done to have retangular lists\n",
    "        if np.any(dones):\n",
    "            break\n",
    "    return reward_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Network Architecture\n",
    "An actor-critic structure with continuous action space is used for this project. The policy consists of 3 parts, a shared hidden layers, actor, and critic.\n",
    "The actor layer outputs the mean value of a normal distribution, from which the agent's action is sampled. The critic layer yields the value function.\n",
    "\n",
    "- Shared layer:\n",
    "```\n",
    "Input State(48,48,3) -> Conv2d(3, 16, 5, 2) -> BatchNorm2d(16) -> Conv2d(16, 32, 5, 2)-> BatchNorm2d(32)\n",
    "-> Conv2d(32, 32, 5, 2) -> BatchNorm2d(32) -> Dense(128) -> LeakyReLU -> Dense(128) -> LeakyReLU -> Dense(64) -> LeakyReLU\n",
    "```\n",
    "- Actor and Critic layers:\n",
    "```\n",
    "LeakyRelu -> Dense(64) -> LeakyRelu -> Dense(4)-> tanh -> Actor's output\n",
    "LeakyReLU -> Dense(64) -> LeakyRelu -> Dense(1) -> Critic's output\n",
    "```\n",
    "\n",
    "### Model update using PPO/GAE\n",
    "The hyperparameters used during training are:\n",
    "\n",
    "Parameter | Value | Description\n",
    "------------ | ------------- | -------------\n",
    "Number of Agents | 1 | Number of agents trained simultaneously\n",
    "tmax | 20 | Maximum number of steps per episode\n",
    "Epochs | 10 | Number of training epoch per batch sampling\n",
    "Batch size | 128 | Size of batch taken from the accumulated  trajectories\n",
    "Discount (gamma) | 0.993 | Discount rate \n",
    "Epsilon | 0.07 | Ratio used to clip r = new_probs/old_probs during training\n",
    "Gradient clip | 10.0 | Maximum gradient norm \n",
    "Beta | 0.01 | Entropy coefficient \n",
    "Tau | 0.95 | tau coefficient in GAE\n",
    "Learning rate | 2e-4 | Learning rate \n",
    "Optimizer | Adam | Optimization method\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run your own policy!\n",
    "policy=ActorCritic(state_size=(screen_height, screen_width),\n",
    "              action_size=action_size,\n",
    "              shared_layers=[128, 64],\n",
    "              critic_hidden_layers=[64],\n",
    "              actor_hidden_layers=[64],\n",
    "              init_type='xavier-uniform',\n",
    "              seed=0).to(device)\n",
    "\n",
    "# we use the adam optimizer with learning rate 2e-4\n",
    "# optim.SGD is also possible\n",
    "optimizer = optim.Adam(policy.parameters(), lr=2e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "PATH = 'policy_ppo.pt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best mean reward updated 45.000 -> 45.500, model saved\n",
      "Best mean reward updated 45.500 -> 45.667, model saved\n",
      "Best mean reward updated 45.667 -> 47.250, model saved\n",
      "Best mean reward updated 47.250 -> 47.800, model saved\n",
      "Best mean reward updated 47.800 -> 48.667, model saved\n",
      "Best mean reward updated 48.667 -> 48.900, model saved\n",
      "Best mean reward updated 48.900 -> 48.909, model saved\n",
      "Best mean reward updated 48.909 -> 49.083, model saved\n",
      "Best mean reward updated 49.083 -> 49.188, model saved\n",
      "Best mean reward updated 49.188 -> 49.588, model saved\n",
      "Best mean reward updated 49.588 -> 49.952, model saved\n",
      "Best mean reward updated 49.952 -> 49.955, model saved\n",
      "Best mean reward updated 49.955 -> 50.304, model saved\n",
      "Best mean reward updated 50.304 -> 50.542, model saved\n",
      "Best mean reward updated 50.542 -> 51.160, model saved\n",
      "Environment solved in 26 seasons!\tAverage Score: 51.12\n",
      "Average Score: 51.12\n",
      "Elapsed time: 1:28:52.548061\n"
     ]
    }
   ],
   "source": [
    "writer = SummaryWriter()\n",
    "best_mean_reward = None\n",
    "\n",
    "scores_window = deque(maxlen=100)  # last 100 scores\n",
    "\n",
    "discount = 0.993\n",
    "epsilon = 0.07\n",
    "beta = .01\n",
    "opt_epoch = 10\n",
    "season = 1000000\n",
    "batch_size = 128\n",
    "tmax = 1000 #env episode steps\n",
    "save_scores = []\n",
    "start_time = timeit.default_timer()\n",
    "\n",
    "for s in range(season):\n",
    "    policy.eval()\n",
    "    old_probs_lst, states_lst, actions_lst, rewards_lst, values_lst, dones_list = collect_trajectories(envs=env,\n",
    "                                                                                                       policy=policy,\n",
    "                                                                                                       tmax=tmax,\n",
    "                                                                                                       nrand = 5)\n",
    "\n",
    "    season_score = rewards_lst.sum(dim=0).item()\n",
    "    scores_window.append(season_score)\n",
    "    save_scores.append(season_score)\n",
    "    \n",
    "    gea, target_value = calc_returns(rewards = rewards_lst,\n",
    "                                     values = values_lst,\n",
    "                                     dones=dones_list)\n",
    "    gea = (gea - gea.mean()) / (gea.std() + 1e-8)\n",
    "\n",
    "    policy.train()\n",
    "\n",
    "    # cat all agents\n",
    "    def concat_all(v):\n",
    "        #print(v.shape)\n",
    "        if len(v.shape) == 3:#actions\n",
    "            return v.reshape([-1, v.shape[-1]])\n",
    "        if len(v.shape) == 5:#states\n",
    "            v = v.reshape([-1, v.shape[-3], v.shape[-2],v.shape[-1]])\n",
    "            #print(v.shape)\n",
    "            return v\n",
    "        return v.reshape([-1])\n",
    "\n",
    "    old_probs_lst = concat_all(old_probs_lst)\n",
    "    states_lst = concat_all(states_lst)\n",
    "    actions_lst = concat_all(actions_lst)\n",
    "    rewards_lst = concat_all(rewards_lst)\n",
    "    values_lst = concat_all(values_lst)\n",
    "    gea = concat_all(gea)\n",
    "    target_value = concat_all(target_value)\n",
    "\n",
    "    # gradient ascent step\n",
    "    n_sample = len(old_probs_lst)//batch_size\n",
    "    idx = np.arange(len(old_probs_lst))\n",
    "    np.random.shuffle(idx)\n",
    "    for epoch in range(opt_epoch):\n",
    "        for b in range(n_sample):\n",
    "            ind = idx[b*batch_size:(b+1)*batch_size]\n",
    "            g = gea[ind]\n",
    "            tv = target_value[ind]\n",
    "            actions = actions_lst[ind]\n",
    "            old_probs = old_probs_lst[ind]\n",
    "\n",
    "            action_est, values = policy(states_lst[ind])\n",
    "            sigma = nn.Parameter(torch.zeros(action_size))\n",
    "            dist = torch.distributions.Normal(action_est, F.softplus(sigma).to(device))\n",
    "            log_probs = dist.log_prob(actions)\n",
    "            log_probs = torch.sum(log_probs, dim=-1)\n",
    "            entropy = torch.sum(dist.entropy(), dim=-1)\n",
    "\n",
    "            ratio = torch.exp(log_probs - old_probs)\n",
    "            ratio_clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)\n",
    "            L_CLIP = torch.mean(torch.min(ratio*g, ratio_clipped*g))\n",
    "            # entropy bonus\n",
    "            S = entropy.mean()\n",
    "            # squared-error value function loss\n",
    "            L_VF = 0.5 * (tv - values).pow(2).mean()\n",
    "            # clipped surrogate\n",
    "            L = -(L_CLIP - L_VF + beta*S)\n",
    "            optimizer.zero_grad()\n",
    "            # This may need retain_graph=True on the backward pass\n",
    "            # as pytorch automatically frees the computational graph after\n",
    "            # the backward pass to save memory\n",
    "            # Without this, the chain of derivative may get lost\n",
    "            L.backward(retain_graph=True)\n",
    "            torch.nn.utils.clip_grad_norm_(policy.parameters(), 10.0)\n",
    "            optimizer.step()\n",
    "            del(L)\n",
    "\n",
    "    # the clipping parameter reduces as time goes on\n",
    "    epsilon*=.999\n",
    "    \n",
    "    # the regulation term also reduces\n",
    "    # this reduces exploration in later runs\n",
    "    beta*=.998\n",
    "\n",
    "    mean_reward = np.mean(scores_window)\n",
    "    writer.add_scalar(\"epsilon\", epsilon, s)\n",
    "    writer.add_scalar(\"beta\", beta, s)\n",
    "    # display some progress every n iterations\n",
    "    if best_mean_reward is None or best_mean_reward < mean_reward:\n",
    "                # For saving the model and possibly resuming training\n",
    "                torch.save({\n",
    "                        'policy_state_dict': policy.state_dict(),\n",
    "                        'optimizer_state_dict': optimizer.state_dict(),\n",
    "                        'epsilon': epsilon,\n",
    "                        'beta': beta\n",
    "                        }, PATH)\n",
    "                if best_mean_reward is not None:\n",
    "                    print(\"Best mean reward updated %.3f -> %.3f, model saved\" % (best_mean_reward, mean_reward))\n",
    "                best_mean_reward = mean_reward\n",
    "    if s>=25 and mean_reward>50:\n",
    "        print('Environment solved in {:d} seasons!\\tAverage Score: {:.2f}'.format(s+1, mean_reward))\n",
    "        break\n",
    "\n",
    "\n",
    "print('Average Score: {:.2f}'.format(mean_reward))\n",
    "elapsed = timeit.default_timer() - start_time\n",
    "print(\"Elapsed time: {}\".format(timedelta(seconds=elapsed)))\n",
    "writer.close()\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAEGCAYAAABiq/5QAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXycV3no8d+ZGW0jjbVLtrV5k5ckxCvZSbwAYWsJhLIEWpZeQrlkoQUuKRRKS3tvL5ReslCahJZQlgRIAkkDTUhkK7uVeLdjyZGs1bKsbbTOSJrt3D9mRpZlLSNp3pl3Zp7v5+OPpdG8M+fVSM8cPe9znqO01gghhEgdlngPQAghRGxJ4BdCiBQjgV8IIVKMBH4hhEgxEviFECLF2OI9gEgUFRXpVatWLepYl8tFdnZ2dAdkcnLOqUHOOTUs5ZwPHjzYp7Uunn57QgT+VatWceDAgUUdW1tby86dO6M7IJOTc04Ncs6pYSnnrJRqm+l2SfUIIUSKkcAvhBApRgK/EEKkGAn8QgiRYiTwCyFEipHAL4QQKUYCvxBCpBgJ/EIIEQPD415+c/hMvIcBSOAXQoiYePTAGf7yl0fpcLrjPRQJ/EIIEQvNfaMA9IyMx3kkEviFECImWvpcAPSOTMR5JBL4hRAiJlp6Q4F/1BPnkUjgF0IIw415/JwdCqZ4kn7Gr5TKU0o9qpRqUErVK6WuVkp9SynVqZQ6Evr3HiPHIIQQ8dba75r82AyB3+i2zHcDT2utP6SUSgfswI3A/9Na/7PBzy2EEKYQzu+nWRV9o0kc+JVSucD1wKcAtNYewKOUMuophRDClMKB/7KyXFPM+JXW2pgHVmoL8ABwEtgMHATuBL5C8M1gGDgAfElrPTDD8bcCtwKUlpZuf+SRRxY1jtHRUXJychZ1bKKSc04Ncs6J40fHJ3ijz8/GQgtvOgN8b6c94mOXcs67du06qLXeMf12IwP/DmA/cK3Wuk4pdTfBYH8f0Ado4NvACq31Z+Z6rB07dmjZgStycs6pQc45cdz8w1dIt1q4rGwZ//lqGw3ffheRZj+WuAPXjIHfyIu7Z4AzWuu60OePAtu01t1aa7/WOgA8CFxh4BiEECLuWvpcrCrKptiRwYQvwMiEL67jMSzwa63PAR1KqQ2hm/YAJ5VSK6bc7QPACaPGIIQQ8Tbo9uB0eVgTCvwAfXHO8xtd1XM78PNQRU8z8GngnlD+XwOtwOcMHoMQQsRN+MLu6qJsMtKCc+3ekQnWFMfvWoWhgV9rfQSYnl/6UyOfUwghzCRcw7+6OBuvPwBAb5xLOo2e8QshREpr6XVhtSgq8u2MjHuB+Kd6pGWDEEIYqLnPRUV+Fuk2C/n2dKwWFfcZvwR+IYQwULiiB8BiURRmp8d9EZcEfiGEMIjWmpY+F6tDgR+g2JFBX5w7dErgF0IIg/SMTOD2+FkzJfAX5WTIjF8IIZJVc2+4lPN86WaxQwK/EEIkramlnGHFjgz6XRMEAsa0y4mEBH4hhDBIS5+LDJuFFcsyJ28rysnA69cMjXnjNi4J/EIIYZDmXherCrOxWM43ZAu3bYhnSacEfiGEMEhL3+gFFT0AxTnx79cjgV8IIQzg8wdod7ovyO8DFDvSAZnxCyFE0ukcHMPr1zPM+IP5/nhW9kjgF0IIA4S7cq6ZFviXZdlIt1pkxi+EEMlmajvmqZRSFOXEt22DBH4hhDBAS58LR6aNguz0i74W77YNEviFEMIALX0u1hRlz7i3brxX70rgF0IIAzT3ui5K84TFu1+PBH4hhIiyca+fs0NjF/TomarYkYHTNYE/Tm0bJPALIUxPa81P97cx7Ilff5uFaHe60ZqLavjDih0ZBDQ4XfHJ80vgF0KYXlu/m2/89gQvd/riPZSITHblLJw91QPxq+WXwC+EML02pxuAvrFAnEcSmXAp56oi+4xfD/fr6YtTLb8EfiGE
6bWH2hv3jyVGqqelb5RiRwaOzLQZv14sM34hhJhbewLO+Ger6AEoinOHTgn8QgjTOx/4NVqbf9YfruGfTXa6law0a9w6dErgF0KYXlt/MPCP+2F4zNwXeIfHvfSNeuac8Sulgou4ZMYvhBAX01rT4XSzIjfY1bJjwB3nEc2tdfLC7uyBH4hrvx4J/EIIU+t3eXB5/Fy9thAItjs2s9m6ck4X7NcjgV8IIS4Szu9fs7YIgDMD5g78zb0ulILKwplLOcPi2bZBAr8QwtQ6QoF/c3kuGVboNHngb+lzUZ6fRYbNOuf9ih0ZDLi9eP2xr1SSwC+EMLXwhd2KAjuFWYozJs/xB0s5Z+7RM1V4EVd/HNozS+AXQphau9NN6bIMMtOsFGVZTJ3j11rPW8oZFs+2DRL4hRCm1t7vpqogGEiLMpWpc/x9ox5GJ3ysmie/D+dn/L2j40YP6yIS+IUQptbudFNREAykRVmKoTEvI+PeOI9qZpPbLRZHkOoJzfj7RpIs1aOUylNKPaqUalBK1SulrlZKFSilnlVKNYb+zzdyDEKIxDXu9XNueJyq0Ay6MCsYssya7mnpGwXmL+WEqTP+5Ev13A08rbXeCGwG6oG7gBqtdTVQE/pcCCEuEr6QWzllxg/mrexp7nORbrWwMi9r3vtmpllxZNiSK8evlMoFrgf+HUBr7dFaDwLvB34SuttPgJuMGoMQIrGFa/jDqZ7CUOA3a56/pddFVaEdq+XifXZnEq+2DTYDH3s10Av8WCm1GTgI3AmUaq27Qvc5B5TOdLBS6lbgVoDS0lJqa2sXNYjR0dFFH5uo5JxTQyqcc01rMJff2XCEkRaF1eMmzaJ49dibVHla4zu4GZxod7Pcbon4dbH5x2jqGJvz/ka8zkYGfhuwDbhda12nlLqbaWkdrbVWSs3Yak9r/QDwAMCOHTv0zp07FzWI2tpaFntsopJzTg2pcM7P/9cb2NM7+KN37kQpRW1tLRWFYHE42Llze7yHdwF/QNP37NP80bYqdu7cFNExv+48RP254TlfRyNeZyNz/GeAM1rrutDnjxJ8I+hWSq0ACP3fY+AYhBAJrMPpprLAjlLnUydleVmmTPWcHRzD4w/M2ZVzumJHfNo2GBb4tdbngA6l1IbQTXuAk8CTwCdDt30SeMKoMQghEltbv3vywm5Yeb7dlBd3J0s5FxD4i3LSGRn3Me71GzWsGRmZ6gG4Hfi5UiodaAY+TfDN5ldKqT8H2oAPGzwGIUQC0lrT7nRzw/riC24vz8+i3+XB7fFhTzc6hEXufA3/wmb8ENx7tzx//kVf0WLod01rfQTYMcOX9hj5vEKIxNczMsGELzBZwx9Wnh8slTw7OMa6Ekc8hjajlj4XORm2yYVZkZis5R+JbeCXlbtCCFOaXsoZVhaqke8wWbqnObTP7tTrEfOJV78eCfxCCFNqD3XlrCq8MHUSnhmbLc/f0jc6765b051P9cS2bYMEfiGEKbU53Sh1foYfVuLIIM1qrmZtEz4/nQNjC7qwC1CYLTN+IYSY1OF0szI3i3TbhWHKYlGszMsyVb+eDqebgI6sR89U6TYLefa0mHfolMAvhDCldufFpZxhwVp+82zI0ty78FLOsOKcjJh36JTAL4QwpZlq+MPK87NMleMPl3IuNMcP8enXI4FfCGE6bo+PvtGJWTcsL8+30zMyEfOFT7Np6XNRlJNOblbago+Nx6brEviFEKYTLuWcK9UD0DUU+92rZtLc52JV4cJn+xCc8ffJjF8IkerCpZxzpXoA0+T5W0I1/ItR7MjA7fHjmvBFeVSzk8AvhDCd8Ix/+qrdsLJQ4I9mnv+5k908dezsgo8bnfDROzKxoFYNU4UXccVy1m+eRhdCCBHS7nTjyLTNmjNfviwTqyW6tfzffeYUTb2jLF+WyY5VBREf1xq6sLvQUs6wqW0bpi9WM4rM+IUQptPudFNVaJ+1/YHNamH5ssyopXrGvX6aekfxBzR3PHyYQXfk
5ZXNk105599gfSbFcWjbIIFfCGE67XOUcoaV50dvEVdTTzDof+6GNfSOTvDlXx9D6xn3iLpIS68LpWZPS82nyJEOxDbVI4FfCGEq/oDmzMDYRc3ZpivLj96GLCe7hgH4yI4K7nr3Jp6r7+ahV1ojOralb5SVuVlkplkX9dyF2RlYlMz4hRAprHt4HI8/QFXB3Pnu8nx78L6+wJKf8+TZYbLSrFQVZvOZa1fx9k0l/J/fN3D8zNC8xy6logfAalEUZMd2EZcEfiGEqbTNU8oZVp6XRUDDuSjU8td3DbNhuQOrRaGU4rsf2kxhTjq3PXyIkXHvrMdprZcc+CG4E1dvDNs2SOAXQphKxzylnGGTtfyDS7vAq7WmvmuYS1Yum7wtPzuduz+6lQ6nm6//5sSs+X6ny8PwuG/JgT/WbRsk8AshTKXN6cJqUazIzZzzfmWTi7iWlufvHBxjeNzHphXLLrj9itUF/OXb1/Pk0bP86kDHjMcuZrvFmRQ7MuiTHL8QIlW1O8coy8vCZp07PK3IzUKppS/iqu8aAeCSFRdv4/g/d63jmrWF/O2Tb/Bm98hFX29eYg1/WHFOcMYfaSXRUkngF0KYSriGfz7ptnAt/1IDf7CiZ8PyZRd9zWpRfP8jW8hOt3HbLw4x5rmwKVxLnwubRV20WcxCFTsy8PgCDI/Hpm2DBH4hhKm097vmLeUMK8vLonOJOf6TZ4dZVWgnJ2PmRgYlyzL5l49s4c3uUf7+qTcu+FpLr4vKQvu8f53MJ9Z770rgF0KYxvC4lwG3l6oIA395FGr5688NX5Tfn+6G9cV8fudaHn6tgyePnu/n09LnWnKaB6buvSuBXwiRYubryjldWX4W54bG8fkXV8s/OuGjrd89b+AH+Kt3rGdbZR5fe/w4rX0uAgFNa//SSznhwn49sSCBXwhhGuFSzkhTPeX5dnwBTfciA+apc8H8/iURBP40q4V7PrYVi4LbHz5Mm9PNhC+w6B49U0mqRwiRsiY3YImw7034oupiK3tOng0G/k0r5w/8EHyj+e6fbOZ45xBf/OURAFYVLa5Hz1R5WWnYLEpSPUKI6PH6A1FpbWC0NqebfHsayzIj28JwqRuynOwaYVmmjZXzrBmY6sZLl/PJq6s42jEIwJoozPgtFhXTLRgl8AuR5J4+cY6r/ncNf/348XgPZV4dTjeVC+hJvzJvaYu4wit2Z2v/PJu/fs8mLlmxDEemjdJlGYt67umKHOkxm/HLRixCJKkht5e/ffIEvz0SrEI5emYwziOaX1u/m80VeRHfPzPNSrEjY1GpHn9A03BumI9dUbngYzPTrPzsf1zJ2cGxBb9pzCa8iCsWJPALkYT2NfTw1ceO4XR5uHNPNcPjXn5R104goLFYohOoos3nD9A5OMYfbV6xoOPK8rIW1a+ntd/FuDcQUUXPTAqy0ynITl/UsTMpdmRMtoc2mqR6hEgiI+NevvroMT790Ovk2dP47Reu5S/fsZ41xTlM+AL0xLAfzEJ1DY3jD+h52zFPV56ftagZf3jFbiQVPbFQlJNB/6iHQMD4tg0y4xciSbzc1Mf/evQYXUNjfH7nWr749moybMHNQcILotqdbpYv4EJmLIXbMUdayhlWlp/FH97oXvBfMyfPDmOzKKpLl35xNhqKHRn4AprBMW9U/5KYicz4hUhwbo+Pbz5xgo//qI4Mm4VHP38NX33XxsmgD+cXRIXLJc2oPcJ2zNOV59vx+AMLzo/Xdw2ztjjngu9TPMVyEVfEM36l1HVAtdb6x0qpYiBHa90yzzGtwAjgB3xa6x1KqW8BnwV6Q3f7mtb694sZvBCp7vVWJ1/+9VHanW4+c+1qvnLjBrLSLw5kK/OysKhgHxyzanO6SLdaKF22sL9IppZ0LuTY+q4Rrl5buKDnMlJ4EVff6AQbuLhTaDRFFPiVUn8L7AA2AD8G0oCfAddGcPgurXXftNv+n9b6nxcyUCHEeeNeP480TPDMM69Snp/FI5+9
iivXzB7E0m0WVuZlmXrG3+F0U56fhXWBF5/Lp5R0bq+K7Biny8O54XE2zdCKOV7MOOP/ALAVOASgtT6rlDLPd0yIFOKa8PGhf3uV+i4fH7+ykq+9ZxPZs3SWnKqywE6biQN/u9Md8YrdqRazIUv4wu5iK3qMYMbA79Faa6WUBlBKRXrZXQN/CB13v9b6gdDttyml/gw4AHxJaz0w/UCl1K3ArQClpaXU1tZG+JQXGh0dXfSxiUrOOXlprXng+AQNXX5u3aS5Jr+f1199KaJjbRMTnO7xmfL7pLXmdLebUqttzvHN9jo70uD1k6epVWcier6nW4L76A40n6C20xzlrVprbBY4VN9EbaB98nYjfrYjDfy/UkrdD+QppT4LfAZ4MILjrtNadyqlSoBnlVINwA+BbxN8U/g28L3Q410g9CbxAMCOHTv0zp07IxzqhWpra1nssYlKzjl5/fpAB6+ePcYX317NFtvZBZ3zSZp44elTvPXq6yL6CyGWBt0exp55lqvfUs3Ot62Z9X6zvc6rT7xEICudnTuviOj5nuw5Qomjjz++cddih2yI0rq9ZOUVsHPnlsnbjPjZjqiqJ5SPfxR4jGCe/5ta63sjOK4z9H8P8BvgCq11t9bar7UOEHzziOyVEiLFNfWM8M0n3uCqNQXcvrt6wcebubKnbYHtmKcry8uicwH9ek6enb8HfzzEatP1eQO/UsqqlNqntX5Wa/0VrfWXtdbPRnBcdvg6QCg19E7ghFJq6rK8DwAnFjt4IVLFuNfPbb84TFa6lbs/unXBF0CByYVRZgz8C+3KOV15fhadg2MR7Vnr8QU43TtqysAfq0Zt8/69p7X2K6UCSqlcrfXQAh67FPhNqI+FDfiF1vpppdRPlVJbCKZ6WoHPLWLcQqSUbz91koZzIzz06bcuuNwxbHLG32/iwL+EGf+4N0C/yzNZFjmbpp5RvH7NJRG2Yo6lYkcGRzouuuQZdZEm+kaB40qpZ4HJQmCt9R2zHaC1bgY2z3D7ny50kEKkst8d6+Lnde187vo17NxQsujHybWnsSzTZs4Zf7+bopwM7OmLu/ZQnh98wzgzMDZv4D852arBfIWJxY4M+l0efP7AkvfxnUuk3+XHQ/+EEDHU3u/mrseOsaUijy/fuGHJj1dVmG3OwO90U1mQtejjy6Ys4toyT3fP+q5hMmwWVi2g/XOsFOekozU43R5KHMa11ogo8Gutf6KUSgfWh246pbX2GjYqkbK01txd08g7L1luyj/FY8njC3D7w4dAwb0f20paFGaAlQX2mHWAXIh2p5srVhcs+vhw4I+kWVt91zAblzsMnVEv1tRafiMDf0RnrpTaCTQCPwD+FXhTKXW9YaMSKevN7lG+/1wj979wOt5DibvvPtPA0TNDfPdDly+4cdlsKgvtnBlw449BB8hIeXwBzg6NLekcl2WmkZuVNu8iLq01J7vMWdED5wN/36jH0OeJNNXzPeCdWutTAEqp9cDDwHajBiZS096GHgCef7PX8Dynme1t6ObBF1v406uqeNdlC+tPP5fKAjtev6ZraGwyLx5vwWqc8x1EF6ssL1jZM5dzw+MMur2mDfyx2nQ90t+qtHDQB9Bav0mwX48QUbW3oRubRTHo9nK4w/w7Rhmha2iML/3qKJtWLOPr790U1cc2Yy1/W6hx3GJLOcPK87Pm3Xt3sge/SdOIZgv8B5RSP1JK7Qz9e5BguwUhombA5eFg2wCfuKoKm0VRU98T7yHFnM8f4M5HjjDhC3DfLVvJTItuy+Bw4O8wUeDvWGIpZ1hZaEOWuWr567tGANi43HwVPQDZGTay062G770baeD/PHASuCP072ToNiGi5oXGXgIabtpaxhWrC9jXkHqB/569TbzW4uQfbrqMtcXR3yBkRW4mNouaXClrBm39bjJsFkocS9u0vDzfjsvjZ9A9e93JybPDVBRk4cg0b8KiyGH8Iq5IA78NuFtr/UGt9QeBewBz7F4gkkZNfQ9F
OelcXpbL7o0lnOoeMdXM1GivnO7j3r2N3LytnA9uKzfkOWxWC+X55mrPHCzltC950/KyUHvmufL89V3DptlqcTbFMVi9G2ngrwGmFtlmAc9FfzgiVfn8AWpP9bBzQwkWi2L3xuBCpX2nUmPW3zc6wRcfOcLqomz+/v2XGvpcFQV20wX+he66NZOpG7LMxO3x0dLvMu2F3bBiR4ZpUj2ZWuvR8Cehj81REiAu8Gb3CK1D/ngPY8EOtQ8yPO5jTyjgrynOYXVR9mSVTzLTWvOlXx1lcMzLD27ZZnjnzEoTBX6tNe1Od1TKVcvn6cvfcG4Erc3Vg38mRTnGN2qLNPC7lFLbwp8opXYAC9/WXhjK6w/wmYde597DExE1qzKTmoZu0qyK66qLJm/btaGEV0734/b44jgy4x1qH+T5N3v56rs2xiQoVRXaGXR7GRqL/xrMfpcHt8e/5Au7ALlZaeRk2GYN/JMVPSYP/MWODAbdXjy+gGHPEWng/yLwa6XUi0qpF4FHgNsMG5VYlN8c6uTMwBj945o3u0fnP8BE9tb3cMXqggsuuu3ZVILHF+Dlpv44jsx4exu6sVoUH9puTF5/OjNV9ix2g/WZKKXmrOWv7xrGkWGb/MvArMKLuPpdxs365wz8Sqm3KqWWa61fBzYCvwS8wNPAnButi9jy+QPct6+JNUXB/iM1Dd1xHlHkOpxuGntG2b2x9ILb37qqgJwMG3sT6FwWo6a+h7euyic3KzaVJpUmas/cvsQ+/NMFa/lnDvzhHvxLvYhstFjU8s83478fCK8dvhr4GsG2DQOEdscS5vDEkbO0O93c9e6NVC2zsDeBauDDefzwBd2wdJuFt1UXsbehJ+FSV5HqHByj4dwIe6a96RmpItQMzQwlneE3n2itIg7W8l98XoGApuHciKk2V59NLPbenS/wW7XWztDHHwEe0Fo/prX+BrDOsFGJBfEHNPfta2LTimW845JSthRbOdQ+wIDL2H4f0VLT0MOaomxWF13cLXH3xhK6hyd446z5GotFQ3itwq6Ni2+3vFCOzDQKstNNMeNv63ezfFlm1BaqlednMTzuu+j6RbvTjdvjN+2K3amKctIBDK3smTfwK6XCJQZ7gL1TvmauTTtT2FPHztLS5+LOPetQSrG52EpAB/vdmJ1rwsf+0/0XzfbDdm4oQSmStrpnb0MPVYV21hbHtkVwsLLHNf8dDdYRquGPlvBfDtO7dIYv7Jq9ogfMkep5GHheKfUEwSqeFwGUUuuAhezGJQziD2ju3dvEhlIH77xkOQCrci0U5aRTkwDB8uWmPjz+wKyBv9iRwebyvKQM/GMePy839bF7Y0nM885mKelsd7qX3KNnqtkWcZ3sGsaiYH2p+VM9mWlWlmXa4hf4tdb/CHwJeAi4Tp9PtFqA2w0bVRLrGR4nEMWWuP99ooumnlFu37MOS2gfVotS7NpQwvOnevD5jSsJi4a9DT04MmzsWDV7L/bdG0s4emYwJnuRxtKrzX1M+GZ/0zNSVaGds4PjeOP48zHu9XNueDzKM/6ZF3HVdw2ztjgn6r2PjFLkyDC0NfO85Zxa6/1a699oraduufim1vqQYaNKUgMuD2/7zj6++WR09pcPBDT31DSyriSHd09r3bt7YwnD4z4Othm/f+diaa3Z29DD9euLSbfN/qO4e2MJWkNtkq3iranvITvduqQNSBarosCOP6A5O08bYyOFg3M0SjnDCrLTyUyzzJDqGUmINE+Y0W0bUrPZeZzUtTiZ8AX42f52fnesa8mP98wb53ize5Tbd6/DarkwVXBddRFpVmXqFMkbZ4fpGZmY98LmpSuXUbosI6naN4Tf9N5WXUyGLfazUDO0Zw5XFUVrkxkI1vKX59svKOkcdHvoHBxLrMDvMHb1rgT+GKpr6SczzcLm8lzueuzYkhbQBALBLQrXFGXzvstXXvR1R2YaV64uNHXgr6nvQSnYuaF4zvspFezd88KbfYauZoylhnMjdA2N
xyXNA+dn2fEs6WyPUjvm6aYv4gq3Yk6Eip6wopwM+mTGnxzqmp1sq8znvlu2gYLbHj686ED2XH03DedGuG2G2X7Yro0lNPaMTi6SMZu9p3rYUpE3WcUwl90bSxmd8PF6q3Pe+yaC8Bvyzo1zv+kZpdSRSbrNEtfVu+1ON9npVgqz06P6uNM3ZDlf0WP+C7thxY4MRiZ8jHuN6bslgT9Ghtxe6s8Nc+XqQioK7Hzn5ss52jHIP//h1PwHT6O15p69jVQV2vnjzRfP9sPCDc/MuPK1d2SCox2D7N4Q2Yz32nWFpNsspv4LZiFq6rvZXJ5r6Ibac7FYFBVxbs/c3h9szhbtiqay/CwG3F5cE8EeTye7hinKSY/b93oxjF7EJYE/Rl5vdaI1XLkmeCHv3W9ZwSeuquSBF5oXvOHIvlM9nOgc5gu71s25J+2qomzWFGez95T56vnD+frdmyIL/PZ0G1evMXfqKlJOl4fDHYMxXbQ1k8oCe9xTPdG8sBs2WcsfSvfUm3hz9dkUh2v5DcrzS+CPkbqWftJtFrZU5E3e9jfvvYSNyx381a+OcG5oPKLH0Vpzd00T5flZfGBr2bz3372hhP2n+ydnP2axr6GH5csyF9Qpcc+mElr6XDT3JlYDuulqT/WgNTFt0zCTygI7HU53XNphBAJ6cgOWaJus5R8Yw+sP0Ng9avqOnNPJjD9J1LU42VKRd0EdcWaalftu2ca4N8CdjxzGH0F9/wuNfRztGOQLu9aRNsdsP2z3phI8/gAvNfUtafzR5PEFeOHNXnYtcOHSrg3h1FViz/prGnoocWRwaZwvNlYWZjMy4WNgjq0KjdI7OsGEL2BI4K+YUst/uncUjz+QeDP+UOA3qm2DBP4YGBn3cqJziKtmqNdeV5LDt2+6jLoWJ/fUNM75OFpr7n7uTVbmZnJzhFvzvXVVAY4Mm6n2r32txYnL45+8BhGpigI7G0odSwr8Wmu+8dsT3PHw4UU/xlJ4/QFeONXLrtBOY/EUz5LON7uDlTaVhdFvVVGUk0G61cKZwbGEatUwVUHogrfM+BPYgdYBAhquXFM449c/tL2cD24t4969jbx6evbe8y839XOofZDP71o354KnqdKsFq5fX8zehp6orhheiiud93UAAB2FSURBVJqGbjJsFq5dVzT/nafZtbGE11qcDI8vbpb66MEz/HR/G8+e7I7L9+NA6wAjE76Ir20YKZxfj0fg//HLreTZ09hRlR/1x7ZYFGWh9sz1XSOk2yysiXEvpKVKs1ooyE6XwJ/I9rf0k2ZVbKuc/Yf82zddxqrCbO585DD9M/x5F8ztv8nyZZl8eMfCNuzYvbGEnhFzdLgML1y6em0hWekLX7i0Z1MJvoDmxTcXnrpq6hnhm0+8QXa6lTGvf85NuY2yt6GbdKuF6xbxphdtFaGLoO39sW3WdvzMEHsbevjs29YYts1keX4WnQPBGf/60pyI0qJmU5xj3N67iffdSEB1zU4uL8+bM9BlZ9i495atDI55+dKvj140G93f7OT11gE+v3Ptgld67txQbJoOl819Ltr63QtO84Rtrcgjz5624HMZ9/q57ReHyUq38t0/2QxAY8/IosawFDUNPVy1ttDwfXUjkZVupdiREfMZ/901jeRmpfFnV1cZ9hxlecFa/pNnh9m0PLHSPGFFDpnxJyzXhI/jnUNctWb+fiyXrszlG+/dRO2pXn70UvMFX7unppESRwYfeWvFgsdQmJPBloo8U9TzhzeIWWwpo81q4Yb1xdSe6onoYnjYt586ScO5Eb734c1cuzY4226M8faUrX0umntd7J5npXIsVcW4pPNE5xDP1Xfz59etvmCbzWgrz8+ib9RDv8uTUCt2pyo2cNN1CfwGO9g2gD+guXL1zPn96T5xVRXvvmw533n6FIfagw3WXmtx8mpzP5+7Ye2iuwvu2VjC0TND9IxEVjZqlL0NPWwodSxpx6XdG0vod3k4emYwovv/7lgXP69r53PX
r2HXhhJy7WmUODJo7Ilt4D+/01h8yzinCpd0xsp9e5twZNr45DWrDH2esin76ibahd2wYkcGfSMeQ8ptDQ38SqlWpdRxpdQRpdSB0G0FSqlnlVKNof+jf3XHROpa+rFaFNsjvIillOKfbr6c0mWZ3PHwYYbGvNy7t5GinHRuuaJy0eMIB5vaOC7mGhrz8nqrc8kXNm9YX4zVoiKqVGrvd3PXY8fYUpHHl2/cMHl7dWlOXAJ/dUlOVPvPL1VloZ2u4XEmfMa0Bpiq4dwwT79xjk9fu9rw/YWnTiwSNtWTk8GY18+4AS9NLGb8u7TWW7TWO0Kf3wXUaK2rgZrQ50mrrtnJW8pyF5TTzc1K495btnJuaJw/+/c6Xmzs49br1yzqYmjYphUOVuRmxnUv3hcbe/EF9KLz+2F59nS2V+ZTM8+5eHwBbn/4ECi492NbL7jAV13ioKl7JGaLl0YnfNS1zL7TWLxUFtjRmlk3KI+me/c2kZNh4zPXrjL8ucKLuMryssi1x2YT+2gL1/IPTSTYjH8W7wd+Evr4J8BNcRhDTIx5/Bw9MzjZpmEhtlXm8+UbN3D0zBAF2el84qqlXQhTSrFrYwkvNvbGZHY3k70NPeTZ09g6R3VTpHZvKuFk1zBdQ7MHrO8+08DRM0N85+bLL2r9u64kB5fHT1eEK6aX6qXGXrx+bcrAD8aXdDZ2j/D741188poq8uzRbco2k9JlmdgsKmHTPHA+8A97oh/4jS4t0MAflFIauF9r/QBQqrUON6M/B8yY8FRK3QrcClBaWkptbe2iBjA6OrroY5fqZL8fr19jH+2ktnbhF1bXa827VtlYm6d47ZWXIj5utnMu8fpwefw8+NtaLiuKbQ/4gNY8e9zNZUVWXnzh+SU/nmM02NX0/idfYmdF2kXnfKTHx4OHJthdaSOr/xS1tRc2w3M5g29+jz37Mm8pNr7C5ufHJ7DbYLT1GLXt0Vm4FY2f7cGJ4PexZv9RVJdxM+N/OzpOugU2qi5qa88t+nEWcs7vWmVjXdZg3H7/l+rsaIA1uRbGx8aifw5aa8P+AWWh/0uAo8D1wOC0+wzM9zjbt2/Xi7Vv375FH7tU3/vDKb36rqf08Jgnps872zm7J3x6/dd/r7/15ImYjkdrrQ+0OnXVV5/STxzpjMrjBQIBfd3/rdF//tBrWusLz7lrcExv+btn9Lu+/4Ie8/hmPL5/dEJXffUp/eALp6Mynrn4/QG9/dvP6tt+cSiqjxuNn+1AIKA3/M3v9d//1xtLH9AsmnpG9Oq7ntL/+/cnl/xY8fx9jpelnDNwQM8QUw1N9WitO0P/9wC/Aa4AupVSKwBC/8e/uNwgdc39XLoy19CytYXISrdyzdpCaup7Yt6Ya19DD1aL4obq6JQyKqXYvaGEl5r6LuhZ7g9o7nzkMBO+APfdsnXWKqiC7HQKs9NpisEF3uOdQ/SNTiz52oYRlFKGb7z+g71NZNisfPZtawx7DrEwhgV+pVS2UsoR/hh4J3ACeBL4ZOhunwSeMGoM8TTu9XO4Y5Ar47Cf6lx2byql3enmdG9sV2vWNPSwvSo/qhfadm8qZdwbuKDNxT01jdS1OPmHmy5jbXHOnMevK4lNZU9NQw8WFaxGMqPKgmzDNutp7XPx2yOdfOKqyog23BGxYeSMvxR4SSl1FHgN+J3W+mngn4B3KKUagbeHPk86RzsG8fgCs/bniZfwxcVYNm07G2qWFe0Z75WrC7CnWyfr41853cc9exu5eVs5H4ygiV11aQ6NMajs2dfQw7bKfPKjvNNUtIRn/EZ8H36wr4k0q4XPXi+zfTMxLPBrrZu11ptD/y7VWv9j6PZ+rfUerXW11vrtWuvk2EtvmroWJ0rBFavMNeMvy8ti43IHNTFcxTu56UqUA39mmpVr1xWxt6GH4QnNFx85wuqibP7+/ZdGdHx1iYPhcZ9hy+IBuofHOd45ZIqmbLOpKrQz5vVHfZVoh9PN44c7ueXKyoTa/SoV
yMpdg+xv7mfj8mWmrCHevbGE11sHGBqLTR/2vfU9VBRksa5k7tTLYuzZWELn4BjfPTDO4JiXH9yyLeI1E9Wh8RiZ7gn/ZRXvTVfmEi7pjPYK3n+tbcJqUfzFDWuj+rhi6STwG8DjC3CofcB0+f2w3RtL8Ac0LzYav4p33Ovn5dN97NlYGvW9VeF8z5+OkQDffN8lC6rbXlcaCvzdxjVrq2nooSwvi/Wl0X/Ti5YKA2r5zwy4efTgGT761gpKl8ls32wk8Bvg2JlBxr2BiBqzxcPWyvxgh8sYrOL97eFOxr0B3nmJMTPe0mWZ7NxQzDUrbXz8yoW1tCjOySA3K82wGf+418/LTX3sXuBOY7FWnp+FUkS1Wdu/PX8aQGb7JiWB3wB1LcHLFldE2Jgt1qwWxa4NJexbYIfLhfL6A/ygtonN5blcvda478VDn76CWy/PWHBwVUpRbWBlT12LE7fHb+r8PgSvlSxflhm1GX/X0Bi/ev0Mf7KjgpV5WfMfIGJOAr8B9jf3s6HUMbl9mhnt2ljCgNvLkY7IOlwuxm8Pd9LhHOOOPdWmnfFWl+YYVsu/t76bzDQLV5ussmsmlQX2qJV0/lvtaQJa83mZ7ZuWBP4o8/oDHGwbWFR/nli6oTrY4dKoHv0+f4D79jVx6cplputPM9W6EgdOl2fGXc+WQmvN3lM9XLeuaNGttGMpWou4uofHefj1Dm7eVn5RfyRhHhL4o+xE5xBujz/i/vvxkhva73S+DpeL9eTRs7T1u0092wfjKnuaekbpcI6Zqvf+XCoL7PSMTDDmWVoDv/ufb8Yf0Hxh17oojUwYQQJ/lJ3P75t7xg/B6p6GcyOcjfLes/6A5r69TWxc7uAdm8wd+KoNquypaTBm7YJRwnsEdAwsftbfMzLOz+va+MDWMlPtOSAuJoE/yuqa+1lbnD3ZUtXM9oQuOkZ7L96njp2luc/FnXuqsVjMO9sHWL4sE0eGLeoz/r0NPVyyYhnLcxOjlHGyPfMS8vw/erEFrz8gs/0EIIE/ivwBzYHWAdO1aZjN2uIcKgvsUQ38gYDm3r1NrC/N4cZLl0ftcY2ilGJdaU5U998ddHs42DYw+caaCKoKswFoW2Sev390gp++2sb7t5Sxuig7mkMTBpDAH0Unzw4zMuEz7cKt6ZRSvH1TKS829nLy7HBUHvO/T5yjqWeU23ebf7YfFu2Szpea+vAHNDs3JE7gz7enkZNhW/Tq3R+91MK4zy+z/QQhgT+K6lqCXSKvSpAZP8AXdq0l357ObQ8fwjXhW9JjBWf7jawtzuY9b1kRpREar7rEQd/oBAMuT1Qeb39zPzkZNjaX50bl8WJBKUXFIit7Blwe/vOVVt53+UpD2nKI6JPAH0X7m52sKrQn1BL1wpwMvv/RLbT0ufjGEyeW9Fh/ONlNw7kRbt9djTVBZvtwvnVDU290Zv11zU62V+VjsybWr1dVgZ22/oW36/73l1pwefzcvltm+4kisX4yTSwQ0Lze6jR9GedMrllbxO27q3n8UCePHTyzqMfQWnNPTSOri7J53+WJM9uHKSWdUcjz949O0Ngzavp1HDOpLLTTMTBGYAGruYfcXh56pZX3vGU560sdBo5ORJME/iipPzfM0Jg3IX/hAe7YvY4rVhfwjSdOcHoRM9/n6ns42TXMF3atS7iZ7srcLOzpVhp7ll7S+VqonDcRJwCVBXY8vgDdI5FvQP8fL7cwOuHj9t3VBo5MRFti/YaaWF1z6Bc+gfL7U9msFu756FYybBa+8PNDF2xnOJ/wbL+ywM5NW1YaOEpjWCyKdSXRad1Q1+IkK83K5QmU3w9baEnn8LiX/3i5hRsvLV1QV1QRfxL4o6SupZ/y/CzKErgp1fLcTP7lw1toODfCP/6uPuLjak/1crxziNsScLYftq4kOiWd+5v72V6VT1oCfh8qF9ie+ScvtzIyLrP9RJR4P50mFAhoXmtxJlQ1z2x2bSzhs29bzU/3t/Hfx7vmvb/W
mrtrGinLy+ID28piMEJjVJc4ODc8zvD44jenGXR7ONU9kjDlvNOV5WdhUZEF/tEJHz96qYW3byrhsrLE++sm1Ungj4LGnlEG3N6E/YWf7is3bmRzRR7/67Fj89Z1v9jYx5GOQb6wa11CznLDwhd4l5Luea3FidaJm+5Ls1pYmZcVUeD/ySutDI15ZbafoBL3N9VEErF+fy7pNgv3fnQraLj94cN4/YEZ7xee7a/MzeTm7Yk724fzPXualpDuqWtxkmGzsLkicWfAVYX2eTdkcU34+NGLzezcUMzmirwYjUxEkwT+KKhrdrIyN5Py/MTN709XWWjnn26+nCMdg/zzM6dmvM+rp/s52DbA53euJcNm/tbDcynPt5NhsyypsqeupZ+tlXkJ/b2oLLDP+1fez/a3MeD2csceme0nKgn8S6S1pq6lnyvXFJq6/fBivPfyFXz8ykruf6GZfacu7ufz/ZpGli/L5MNvrYjD6KLLalGsLV5864bhcS8nzw4nZBnnVBUFdvpdHkZnWcU95vHzwAvNvK26iG2V+TEenYgWCfxLdLrXRd+oJ2ny+9N9432XsHG5gy/96ijnhs7Xd+9v7ue1Fid/ccOahJ7hTlW9hGZtB1qdBDQJu44jrKog2GBttpLOn9e10e/ycKfM9hOaBP4lCuf3E/WC3nwy06zcd8s2xjx+vvjLw5N79N5T00ixI4OPXrGwDc7NrLokh87BsUX1LKprdpJutST8LHiuks5xr5/7X2jmmrWF7FiV2G9wqU4C/xLVNTspcWSwKok3nlhXksPfv/9S9jc7uXdvI6+3OnnldD+fu35NQmwrGKl1JcGWA4tZuby/uZ/NFbkJ//0Ib6DS7ry4Z88jr7XTOzIhuf0kIIF/CZI5vz/dh7aX84GtZdxT08hfP36copx0Pn5lVbyHFVXnd+NaWOAfnfBx4uxwUlR15WalkZuVdtGMf9zr54fPn+aK1QVJcZ6pzhbvAZiF1prfHO7k4dfa8UXYpCoQ0HQPTyRtfn8qpRTfvukyjnQM0tQzyl+/eyNZ6Yk9u52uqsBOmlUt+ALvgVYn/oBO+Au7YZUFF5d0/vpAB93DE/zLh7fEaVQimiTwE9wr9GuPn+C5+m7Wl+YsqK3yOy8pTYidpqIhJ8PG/X+6nV/UtfOnVyfXbB+C/YrWFOXQtMCSzroWJzaLYltVctS0VxbaeaNzaPLzCZ+ff609zfaqfK5Zmxxvbqku5QP/fx09yzefOIHL4+fr79nEZ65bnVC95GNtfamDb/3xpfEehmHWleZw/MzQ/Hecoq65n8vLc7GnJ8evU2WBnWdOnMMf0FgtiscOdtI1NM7/vfnypE9ppoqUzfE7XR6+8ItD3P7wYSoL7Pz+juv47PVrJOinuOqSHDoG3Ix5IutO6vb4OHZmKKmquqoK7PgCmrODY3j9AX6wr4ktFXm8rboo3kMTUZIcU5QF+sMb5/jab44zNOblKzdu4HPXr0nYrpIiutaXOtA6WNkTSfOxQ22D+AI6qa7zhEs6O5xuXj3dT+fgGP9w02Uy208iKRX4h9xe/u6/3uDxw51csmIZP/3zK6WPuLjA1GZtkQT+upZ+rBaVVHXtFaHAf7rPxYMvNPOWslx2biiO86hENBke+JVSVuAA0Km1fp9S6iHgBiCcSP2U1vqI0eOoPdXDVx87Rt+ohzt2r+O23dWk22SWLy5UVZiNzaIi7tlT1+zkspXLyMlInjnUyrwsbBbFgy800+508+Cf7ZDZfpKJxU/rnUA9MHVq/RWt9aMxeG7GfJq7HjvGI693UF2Sw4N/toPLy5Oj+kJEX7rNwqqi7Ihq+ce9fo50DPKpa1cZP7AYsloU5flZtPa72bRiGW/fVBLvIYkoMzTwK6XKgfcC/wj8lZHPNZNXmvr4m5fGcE508Lkb1vCXb1+f8CsrhfGqS3I4dW7+Gf/h9kE8/kBS5ffDKgrstPa7uXPPOpntJyGldWSLlRb14Eo9CvwfwAF8eUqq
52pgAqgB7tJaT8xw7K3ArQClpaXbH3nkkQU//yMNHg6e83Dr5iyq81Mn4I+OjpKTkxPvYcRUNM/58UYP/3XaywPvtJM2R5XXb5s8PNHk5b49drLTYh8cjXydn2n1crzPz19tz8BiosAvP9sLs2vXroNa6x0XfUFrbcg/4H3Av4Y+3gk8Ffp4BaCADOAnwDfne6zt27frxRjz+PTTz+1d1LGJbN++ffEeQsxF85yfONKpq776lK7vGprzfh+9/1X9nrtfiNrzLpS8zqlhKecMHNAzxFQjr25eC/yxUqoVeATYrZT6mda6KzSmCeDHwBVGDSAzzUqG1TyzFZEYwpU9c+X5J3x+DrUPJE2bBpFaDAv8Wuu/1lqXa61XAR8F9mqtP6GUWgGggonDm4ATRo1BiMVYXZSNRTFnz56jHUNM+AIJ339fpKZ41KD9XClVTDDdcwT4iziMQYhZZaZZqSrMnrNnT11zcB+GK5Kofl+kjpgEfq11LVAb+nh3LJ5TiKVYVzL3blx1LU42LneQn50ew1EJER2ygkmIGVSX5NDS58LrD1z0Na8/wMG2gaQs4xSpQQK/EDOoLs3BF9C09V+8E9WxM0OMef1J1ZhNpBYJ/ELMoDq0DeNM6Z7wPstXyIxfJCgJ/ELMYG1xDmqWyp66ZifVJTkU5WTEYWRCLJ0EfiFmkJVupTw/66LA7/MHONDqlDJOkdAk8Asxi+oSB43dF5Z0vnF2GJfHLwu3REKTwC/ELKpLcmjuc+GbUtkTzu/LjF8kMgn8QsxiXUkOHl+AjoGxydvqmp2sKcqmxJEZx5EJsTQS+IWYRXVpuLInmO7xBzSvSX5fJAEJ/ELMYl24WVvoAm991zAj4z7J74uEJ4FfiFnkZNhYmZs5OePf3yz5fZEcJPALMYd1pY7JGX9di5PKAjsrcrPiPCohlkYCvxBzWF+SQ1PPKD5/gNdbndKfRyQFCfxCzKG6NIcJX4Cahh4G3V7pzyOSggR+IeawLtSz52f72wBkxi+SggR+IeYQrux5sbGPsrwsKgrscR6REEsngV+IOeRmpVG6LNiMTWb7IllI4BdiHuEWzVLGKZKFBH4h5hFO98jCLZEs4rHZuhAJ5WNXVFKYnU5VoeT3RXKQwC/EPDYsd7BhuSPewxAiaiTVI4QQKUYCvxBCpBgJ/EIIkWIk8AshRIqRwC+EEClGAr8QQqQYCfxCCJFiJPALIUSKUVrreI9hXkqpXqBtkYcXAX1RHE4ikHNODXLOqWEp51yltS6efmNCBP6lUEod0FrviPc4YknOOTXIOacGI85ZUj1CCJFiJPALIUSKSYXA/0C8BxAHcs6pQc45NUT9nJM+xy+EEOJCqTDjF0IIMYUEfiGESDFJHfiVUu9SSp1SSjUppe6K93hiQSnVqpQ6rpQ6opQ6EO/xGEEp9R9KqR6l1IkptxUopZ5VSjWG/s+P5xijbZZz/pZSqjP0Wh9RSr0nnmOMJqVUhVJqn1LqpFLqDaXUnaHbk/Z1nuOco/46J22OXyllBd4E3gGcAV4HPqa1PhnXgRlMKdUK7NBaJ+0iF6XU9cAo8J9a68tCt30HcGqt/yn0Jp+vtf5qPMcZTbOc87eAUa31P8dzbEZQSq0AVmitDymlHMBB4CbgUyTp6zzHOX+YKL/OyTzjvwJo0lo3a609wCPA++M8JhEFWusXAOe0m98P/CT08U8I/sIkjVnOOWlprbu01odCH48A9UAZSfw6z3HOUZfMgb8M6Jjy+RkM+iaajAb+oJQ6qJS6Nd6DiaFSrXVX6ONzQGk8BxNDtymljoVSQUmT9phKKbUK2ArUkSKv87Rzhii/zskc+FPVdVrrbcC7gS+EUgQpRQfzl8mZw7zQD4G1wBagC/hefIcTfUqpHOAx4Ita6+GpX0vW13mGc47665zMgb8TqJjyeXnotqSmte4M/d8D/IZgyisVdIdypOFcaU+cx2M4rXW31tqvtQ4AD5Jkr7VSKo1gAPy51vrx0M1J
/TrPdM5GvM7JHPhfB6qVUquVUunAR4En4zwmQymlskMXhVBKZQPvBE7MfVTSeBL4ZOjjTwJPxHEsMREOgCEfIIlea6WUAv4dqNda/8uULyXt6zzbORvxOidtVQ9AqOzp+4AV+A+t9T/GeUiGUkqtITjLB7ABv0jGc1ZKPQzsJNiuthv4W+C3wK+ASoItvD+stU6ai6GznPNOgn/+a6AV+NyU/HdCU0pdB7wIHAcCoZu/RjDnnZSv8xzn/DGi/DondeAXQghxsWRO9QghhJiBBH4hhEgxEviFECLFSOAXQogUI4FfCCFSjAR+kRKUUl8PdTw8FupweGW8xzSdUupGpdTfhTpQ/ne8xyOSly3eAxDCaEqpq4H3Adu01hNKqSIgPc7DmsnbgH2h/1+K81hEEpMZv0gFK4A+rfUEgNa6T2t9FkAptV0p9Xyoqd0zU9oBfFYp9bpS6qhS6jGllD10+58opU6Ebn8hdFumUurHoX0QDiuldoVu/5RS6nGl1NOh/vHfmWlwSqmPKKWOAHcQXHD4IPBppVRSrzQX8SMLuETSCzW9egmwA88Bv9RaPx/qi/I88H6tda9S6iPAjVrrzyilCrXW/aHj/wHo1lrfq5Q6DrxLa92plMrTWg8qpb4EXBo6biPwB2A9wTYh3yTYZXECOEWwiV4H04SW67+stb5GKVUD3BRqzStE1EmqRyQ9rfWoUmo7wRTKLuCXoU08DgCXAc8G4y5Wgt0PAS4LBfw8IAd4JnT7y8BDSqlfAeHGYdcB94aeq0Ep1UYw8APUaK2HAJRSJ4EqLmwXHrYeaA59nC1BXxhJAr9ICVprP1AL1IZm7Z8kuMPRG1rrq2c45CGCs+6jSqlPEeyLg9b6L0IXht8LHAy9ocxlYsrHfmb4nVPBLTKLAFvozWFFKPVzu9b6xYhPUogISY5fJD2l1AalVPWUm7YQbPB1CigOXfxFKZWmlLo0dB8H0BVKB318ymOt1VrXaa2/CfQSbP39Yvg+Sqn1BBuInYp0fFrrHcDvCO4u9R3g61rrLRL0hVFkxi9SQQ5wr1IqD/ABTcCtWmuPUupDwD1KqVyCvw/fB94AvkGwE2Rv6H9H6LG+G3oTUUANcBRoAH4Y+kvCB3wqVD20kDFuI3hx938C/zLPfYVYErm4K4QQKUZSPUIIkWIk8AshRIqRwC+EEClGAr8QQqQYCfxCCJFiJPALIUSKkcAvhBAp5v8DIz4ZWjBpE+oAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the training curve: total score obtained in each season of training.\n",
    "# Use the explicit Axes interface instead of the pyplot state machine, and\n",
    "# drop the previously unused bare `plt.figure()` handle.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(np.arange(len(save_scores)), save_scores)\n",
    "ax.set_ylabel('Score')\n",
    "ax.set_xlabel('Season #')\n",
    "ax.grid()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Episode: 1, reward: 0.0\n",
      "Episode: 2, reward: 1.0\n",
      "Episode: 3, reward: 0.0\n",
      "Episode: 4, reward: 1.0\n",
      "Episode: 5, reward: 0.0\n",
      "Episode: 6, reward: 1.0\n",
      "Episode: 7, reward: 0.0\n",
      "Episode: 8, reward: 1.0\n",
      "Episode: 9, reward: 1.0\n",
      "Episode: 10, reward: 0.0\n"
     ]
    }
   ],
   "source": [
    "# Roll out the trained policy for a fixed number of test episodes and print\n",
    "# the total reward per episode (1.0 indicates a successful grasp attempt).\n",
    "# The unused `scores_window` deque from the training loop has been removed.\n",
    "episodes = 10\n",
    "env = KukaDiverseObjectEnv(renders=False, isDiscrete=False, removeHeightHack=False, maxSteps=20, isTest=True)\n",
    "env.cid = p.connect(p.DIRECT)\n",
    "\n",
    "# Restore the trained weights from the saved checkpoint.\n",
    "checkpoint = torch.load(PATH)\n",
    "policy.load_state_dict(checkpoint['policy_state_dict'])\n",
    "\n",
    "# Evaluate the model: sum the per-step rewards of each episode.\n",
    "for e in range(episodes):\n",
    "    rewards = eval_policy(envs=env, policy=policy)\n",
    "    reward = np.sum(rewards, 0)\n",
    "    print(\"Episode: {0:d}, reward: {1}\".format(e + 1, reward))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
